diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 229856aa4366..000000000000 --- a/.flake8 +++ /dev/null @@ -1,22 +0,0 @@ -[flake8] -ignore = - ;W503 line break before binary operator - W503, - ;E203 whitespace before ':' - E203, - -; exclude file -exclude = - .tox, - .git, - __pycache__, - build, - dist, - *.pyc, - *.egg-info, - .cache, - .eggs - -max-line-length = 120 - -per-file-ignores = __init__.py:F401 diff --git a/.github/workflows/scripts/check_doc_i18n.py b/.github/workflows/scripts/check_doc_i18n.py index 1aa7283e9e52..1e7f0c33a785 100644 --- a/.github/workflows/scripts/check_doc_i18n.py +++ b/.github/workflows/scripts/check_doc_i18n.py @@ -22,13 +22,13 @@ def compare_dirs(dir1, dir2): # If the corresponding item doesn't exist in the second directory, the directories are different if not os.path.exists(item_path2): - print(f'Found mismatch: {item_path1}, {item_path2}') + print(f"Found mismatch: {item_path1}, {item_path2}") return False # If the corresponding item is a directory, we compare the two directories recursively if os.path.isdir(item_path1) and os.path.isdir(item_path2): if not compare_dirs(item_path1, item_path2): - print(f'Found mismatch: {item_path1}, {item_path2}') + print(f"Found mismatch: {item_path1}, {item_path2}") return False # both are files @@ -37,16 +37,16 @@ def compare_dirs(dir1, dir2): # If the corresponding item is not a file or a directory, the directories are different else: - print(f'Found mismatch: {item_path1}, {item_path2}') + print(f"Found mismatch: {item_path1}, {item_path2}") return False # If all items are the same, the directories are the same return True -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-d', '--directory', help="The directory where the multi-language source files are kept.") + parser.add_argument("-d", "--directory", help="The directory where the multi-language source files are kept.") args = parser.parse_args() i18n_folders = os.listdir(args.directory) @@ -56,7 +56,7 @@ def compare_dirs(dir1, dir2): for i in range(1, len(i18n_folders)): dir1 = i18n_folders[0] dir2 = i18n_folders[i] - print(f'comparing {dir1} vs {dir2}') + print(f"comparing {dir1} vs {dir2}") match = compare_dirs(i18n_folders[0], i18n_folders[i]) if not match: diff --git a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py index 5bec96187e0c..91778f692cc6 100644 --- a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py +++ b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py @@ -4,7 +4,7 @@ def check_inputs(input_list): for path in input_list: - real_path = os.path.join('examples', path) + real_path = os.path.join("examples", path) if not os.path.exists(real_path): return False return True @@ -12,16 +12,16 @@ def check_inputs(input_list): def main(): parser = argparse.ArgumentParser() - parser.add_argument('-f', '--fileNameList', type=str, help="List of file names") + parser.add_argument("-f", "--fileNameList", type=str, help="List of file names") args = parser.parse_args() name_list = args.fileNameList.split(",") is_correct = check_inputs(name_list) if is_correct: - print('success') + print("success") else: - print('failure') + print("failure") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/.github/workflows/scripts/example_checks/check_example_weekly.py b/.github/workflows/scripts/example_checks/check_example_weekly.py index 83eff644e315..95a3d24c9a78 100644 --- a/.github/workflows/scripts/example_checks/check_example_weekly.py +++ b/.github/workflows/scripts/example_checks/check_example_weekly.py @@ -17,21 +17,21 @@ def show_files(path, all_files): def join(input_list, sep=None): - return (sep or ' ').join(input_list) + return (sep or " ").join(input_list) def main(): - contents = show_files('examples/', []) + contents = show_files("examples/", []) all_loc = [] for file_loc in contents: - split_loc = file_loc.split('/') + split_loc = file_loc.split("/") # must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not. if len(split_loc) >= 4: - re_loc = '/'.join(split_loc[1:3]) + re_loc = "/".join(split_loc[1:3]) if re_loc not in all_loc: all_loc.append(re_loc) print(all_loc) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/.github/workflows/scripts/example_checks/detect_changed_example.py b/.github/workflows/scripts/example_checks/detect_changed_example.py index c69d95a552e9..95f671dfb32b 100644 --- a/.github/workflows/scripts/example_checks/detect_changed_example.py +++ b/.github/workflows/scripts/example_checks/detect_changed_example.py @@ -3,7 +3,7 @@ def main(): parser = argparse.ArgumentParser() - parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files") + parser.add_argument("-f", "--fileNameList", type=str, help="The list of changed files") args = parser.parse_args() name_list = args.fileNameList.split(":") folder_need_check = set() @@ -15,10 +15,10 @@ def main(): # - application # - file if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4: - folder_need_check.add('/'.join(loc.split("/")[1:3])) + folder_need_check.add("/".join(loc.split("/")[1:3])) # Output the result using print. Then the shell can get the values. print(list(folder_need_check)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py index 2884e38dd3dd..412b14c7b283 100644 --- a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py +++ b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py @@ -74,16 +74,16 @@ def get_organization_repositories(github_token, organization_name) -> List[str]: # prepare header headers = { - 'Authorization': f'Bearer {github_token}', - 'Accept': 'application/vnd.github+json', - 'X-GitHub-Api-Version': '2022-11-28' + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", } res = requests.get(url, headers=headers).json() repo_list = [] for item in res: - repo_list.append(item['name']) + repo_list.append(item["name"]) return repo_list @@ -97,9 +97,9 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name: """ # prepare header headers = { - 'Authorization': f'Bearer {github_token}', - 'Accept': 'application/vnd.github+json', - 'X-GitHub-Api-Version': '2022-11-28' + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", } user_engagement_count = {} @@ -107,28 +107,28 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name: # do pagination to the API page = 1 while True: - comment_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}' + comment_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}" comment_response = requests.get(comment_api, headers=headers).json() if len(comment_response) == 0: break else: for item in comment_response: - comment_author_relationship = item['author_association'] - if comment_author_relationship != 'MEMBER': + comment_author_relationship = item["author_association"] + if comment_author_relationship != "MEMBER": # if the comment is not made by our member # we don't count this comment towards user engagement continue - issue_id = item['issue_url'].split('/')[-1] - issue_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}' + issue_id = item["issue_url"].split("/")[-1] + issue_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}" issue_response = requests.get(issue_api, headers=headers).json() - issue_author_relationship = issue_response['author_association'] + issue_author_relationship = issue_response["author_association"] - if issue_author_relationship != 'MEMBER': + if issue_author_relationship != "MEMBER": # this means that the issue/PR is not created by our own people # any comments in this issue/PR by our member will be counted towards the leaderboard - member_name = item['user']['login'] + member_name = item["user"]["login"] if member_name in user_engagement_count: user_engagement_count[member_name] += 1 @@ -153,7 +153,7 @@ def _generate_discussion_query(num, cursor: str = None): if cursor is None: offset_str = "" else: - offset_str = f", after: \"{cursor}\"" + offset_str = f', after: "{cursor}"' query = f""" {{ repository(owner: "{org_name}", name: "{repo_name}"){{ @@ -182,7 +182,7 @@ def _generate_comment_reply_count_for_discussion(discussion_number, num, cursor: if cursor is None: offset_str = "" else: - offset_str = f", before: \"{cursor}\"" + offset_str = f', before: "{cursor}"' query = f""" {{ repository(owner: "{org_name}", name: "{repo_name}"){{ @@ -220,8 +220,8 @@ def _generate_comment_reply_count_for_discussion(discussion_number, num, cursor: # a utility function to make call to Github GraphQL API def _call_graphql_api(query): headers = {"Authorization": f"Bearer {github_token}"} - json_data = {'query': query} - response = requests.post('https://api.github.com/graphql', json=json_data, headers=headers) + json_data = {"query": query} + response = requests.post("https://api.github.com/graphql", json=json_data, headers=headers) data = response.json() return data @@ -234,21 +234,21 @@ def _call_graphql_api(query): data = _call_graphql_api(query) found_discussion_out_of_time_range = False - edges = data['data']['repository']['discussions']['edges'] + edges = data["data"]["repository"]["discussions"]["edges"] if len(edges) == 0: break else: # keep the discussion whose author is not a member for edge in edges: # print the discussion title - discussion = edge['node'] - discussion_updated_at = str2datetime(discussion['updatedAt']) + discussion = edge["node"] + discussion_updated_at = str2datetime(discussion["updatedAt"]) # check if the updatedAt is within the last 7 days # if yes, add it to discussion_numbers if discussion_updated_at > since: - if discussion['authorAssociation'] != 'MEMBER': - discussion_numbers.append(discussion['number']) + if discussion["authorAssociation"] != "MEMBER": + discussion_numbers.append(discussion["number"]) else: found_discussion_out_of_time_range = True @@ -256,7 +256,7 @@ def _call_graphql_api(query): break else: # update cursor - cursor = edges[-1]['cursor'] + cursor = edges[-1]["cursor"] # get the discussion comments and replies made by our member user_engagement_count = {} @@ -269,42 +269,42 @@ def _call_graphql_api(query): data = _call_graphql_api(query) # get the comments - edges = data['data']['repository']['discussion']['comments']['edges'] + edges = data["data"]["repository"]["discussion"]["comments"]["edges"] # update the cursor if len(edges) == 0: break else: # update cursor for pagination - cursor = edges[-1]['cursor'] + cursor = edges[-1]["cursor"] for edge in edges: - comment = edge['node'] - if comment['authorAssociation'] == 'MEMBER': + comment = edge["node"] + if comment["authorAssociation"] == "MEMBER": # check if the updatedAt is within the last 7 days # if yes, add it to user_engagement_count - comment_updated_at = datetime.strptime(comment['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") + comment_updated_at = datetime.strptime(comment["updatedAt"], "%Y-%m-%dT%H:%M:%SZ") if comment_updated_at > since: - member_name = comment['author']['login'] + member_name = comment["author"]["login"] if member_name in user_engagement_count: user_engagement_count[member_name] += 1 else: user_engagement_count[member_name] = 1 # get the replies - reply_edges = comment['replies']['edges'] + reply_edges = comment["replies"]["edges"] if len(reply_edges) == 0: continue else: for reply_edge in reply_edges: - reply = reply_edge['node'] - if reply['authorAssociation'] == 'MEMBER': + reply = reply_edge["node"] + if reply["authorAssociation"] == "MEMBER": # check if the updatedAt is within the last 7 days # if yes, add it to discussion_numbers - reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") + reply_updated_at = datetime.strptime(reply["updatedAt"], "%Y-%m-%dT%H:%M:%SZ") if reply_updated_at > since: - member_name = reply['author']['login'] + member_name = reply["author"]["login"] if member_name in user_engagement_count: user_engagement_count[member_name] += 1 else: @@ -312,7 +312,9 @@ def _call_graphql_api(query): return user_engagement_count -def generate_user_engagement_leaderboard_image(github_token: str, org_name: str, repo_list: List[str], output_path: str) -> bool: +def generate_user_engagement_leaderboard_image( + github_token: str, org_name: str, repo_list: List[str], output_path: str +) -> bool: """ Generate the user engagement leaderboard image for stats within the last 7 days @@ -335,16 +337,19 @@ def _update_count(counter): else: total_engagement_count[name] = count - for repo_name in repo_list: print(f"Fetching user engagement count for {repo_name}/{repo_name}") - issue_pr_engagement_count = get_issue_pull_request_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str) - discussion_engagement_count = get_discussion_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime) + issue_pr_engagement_count = get_issue_pull_request_comments( + github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str + ) + discussion_engagement_count = get_discussion_comments( + github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime + ) # update the total engagement count _update_count(issue_pr_engagement_count) _update_count(discussion_engagement_count) - + # prepare the data for plotting x = [] y = [] @@ -363,7 +368,7 @@ def _update_count(counter): # plot the leaderboard xlabel = f"Number of Comments made (since {start_datetime_str})" ylabel = "Member" - title = 'Active User Engagement Leaderboard' + title = "Active User Engagement Leaderboard" plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) return True else: @@ -380,16 +385,16 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou """ # request to the Github API to get the users who have contributed in the last 7 days headers = { - 'Authorization': f'Bearer {github_token}', - 'Accept': 'application/vnd.github+json', - 'X-GitHub-Api-Version': '2022-11-28' + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", } counter = Counter() start_datetime = get_utc_time_one_week_ago() def _get_url(org_name, repo_name, page): - return f'https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed' + return f"https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed" def _iterate_by_page(org_name, repo_name): page = 1 @@ -415,8 +420,8 @@ def _iterate_by_page(org_name, repo_name): # count the pull request and author from response for pr_data in response: - merged_at = pr_data['merged_at'] - author = pr_data['user']['login'] + merged_at = pr_data["merged_at"] + author = pr_data["user"]["login"] if merged_at is None: continue @@ -439,7 +444,7 @@ def _iterate_by_page(org_name, repo_name): _iterate_by_page(org_name, repo_name) # convert unix timestamp to Beijing datetime - bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone('Asia/Shanghai')) + bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone("Asia/Shanghai")) bj_start_datetime_str = datetime2str(bj_start_datetime) contribution_list = counter.to_sorted_list() @@ -452,7 +457,7 @@ def _iterate_by_page(org_name, repo_name): if len(author_list) > 0: xlabel = f"Number of Pull Requests (since {bj_start_datetime_str})" ylabel = "Contributor" - title = 'Active Contributor Leaderboard' + title = "Active Contributor Leaderboard" plot_bar_chart(num_commit_list, author_list, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) return True else: @@ -468,14 +473,14 @@ def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str: image_path (str): the path to the image to be uploaded """ url = "https://open.feishu.cn/open-apis/im/v1/images" - form = {'image_type': 'message', 'image': (open(image_path, 'rb'))} # 需要替换具体的path + form = {"image_type": "message", "image": (open(image_path, "rb"))} # 需要替换具体的path multi_form = MultipartEncoder(form) headers = { - 'Authorization': f'Bearer {lark_tenant_token}', ## 获取tenant_access_token, 需要替换为实际的token + "Authorization": f"Bearer {lark_tenant_token}", ## 获取tenant_access_token, 需要替换为实际的token } - headers['Content-Type'] = multi_form.content_type + headers["Content-Type"] = multi_form.content_type response = requests.request("POST", url, headers=headers, data=multi_form).json() - return response['data']['image_key'] + return response["data"]["image_key"] def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str: @@ -486,10 +491,10 @@ def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str: app_id (str): Lark app id app_secret (str): Lark app secret """ - url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal' - data = {'app_id': app_id, 'app_secret': app_secret} + url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" + data = {"app_id": app_id, "app_secret": app_secret} response = requests.post(url, json=data).json() - return response['tenant_access_token'] + return response["tenant_access_token"] def send_image_to_lark(image_key: str, webhook_url: str) -> None: @@ -516,10 +521,10 @@ def send_message_to_lark(message: str, webhook_url: str): requests.post(webhook_url, json=data) -if __name__ == '__main__': - GITHUB_TOKEN = os.environ['GITHUB_TOKEN'] - CONTRIBUTOR_IMAGE_PATH = 'contributor_leaderboard.png' - USER_ENGAGEMENT_IMAGE_PATH = 'engagement_leaderboard.png' +if __name__ == "__main__": + GITHUB_TOKEN = os.environ["GITHUB_TOKEN"] + CONTRIBUTOR_IMAGE_PATH = "contributor_leaderboard.png" + USER_ENGAGEMENT_IMAGE_PATH = "engagement_leaderboard.png" ORG_NAME = "hpcaitech" # get all open source repositories @@ -527,17 +532,19 @@ def send_message_to_lark(message: str, webhook_url: str): # generate images contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, CONTRIBUTOR_IMAGE_PATH) - engagement_success = generate_user_engagement_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH) + engagement_success = generate_user_engagement_leaderboard_image( + GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH + ) # upload images - APP_ID = os.environ['LARK_APP_ID'] - APP_SECRET = os.environ['LARK_APP_SECRET'] + APP_ID = os.environ["LARK_APP_ID"] + APP_SECRET = os.environ["LARK_APP_SECRET"] LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET) contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH) user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH) # send message to lark - LARK_WEBHOOK_URL = os.environ['LARK_WEBHOOK_URL'] + LARK_WEBHOOK_URL = os.environ["LARK_WEBHOOK_URL"] message = """本周的社区榜单出炉啦! 1. 开发贡献者榜单 2. 用户互动榜单 diff --git a/.github/workflows/scripts/generate_release_draft.py b/.github/workflows/scripts/generate_release_draft.py index dc592e4c977b..7374481005ef 100644 --- a/.github/workflows/scripts/generate_release_draft.py +++ b/.github/workflows/scripts/generate_release_draft.py @@ -7,27 +7,27 @@ import requests -COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits' -TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags' +COMMIT_API = "https://api.github.com/repos/hpcaitech/ColossalAI/commits" +TAGS_API = "https://api.github.com/repos/hpcaitech/ColossalAI/tags" def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--out', type=str, help='output path for the release draft', required=True) - parser.add_argument('--version', type=str, help='current version to release', required=True) + parser.add_argument("--out", type=str, help="output path for the release draft", required=True) + parser.add_argument("--version", type=str, help="current version to release", required=True) return parser.parse_args() def get_latest_tag_commit(headers=None): res = requests.get(url=TAGS_API, headers=headers) data = res.json() - commit_hash = data[0]['commit']['sha'] - version = data[0]['name'] + commit_hash = data[0]["commit"]["sha"] + version = data[0]["name"] return commit_hash, version def get_commit_info(commit_hash, headers=None): - api = f'{COMMIT_API}/{commit_hash}' + api = f"{COMMIT_API}/{commit_hash}" res = requests.get(url=api, headers=headers) return res.json() @@ -37,7 +37,7 @@ def get_all_commit_info(since, headers=None): results = [] while True: - api = f'{COMMIT_API}?since={since}&per_page=100&page={page}' + api = f"{COMMIT_API}?since={since}&per_page=100&page={page}" resp = requests.get(url=api, headers=headers) data = resp.json() @@ -53,21 +53,21 @@ def get_all_commit_info(since, headers=None): def collate_release_info(commit_info_list): results = dict() - pattern = pattern = r'\[.*\]' + pattern = pattern = r"\[.*\]" for commit_info in commit_info_list: - author = commit_info['commit']['author']['name'] + author = commit_info["commit"]["author"]["name"] try: - author_url = commit_info['author']['url'] + author_url = commit_info["author"]["url"] except: # author can be None author_url = None - msg = commit_info['commit']['message'] + msg = commit_info["commit"]["message"] match = re.search(pattern, msg) if match: - tag = match.group().lstrip('[').rstrip(']').capitalize() + tag = match.group().lstrip("[").rstrip("]").capitalize() if tag not in results: results[tag] = [] results[tag].append((msg, author, author_url)) @@ -89,42 +89,43 @@ def generate_release_post_markdown(current_version, last_version, release_info): for msg, author, author_url in v: # only keep the first line - msg = msg.split('\n')[0] + msg = msg.split("\n")[0] if author_url: - item = f'{msg} by [{author}]({author_url})\n' + item = f"{msg} by [{author}]({author_url})\n" else: - item = f'{msg} by {author}\n' - text.append(f'- {item}') + item = f"{msg} by {author}\n" + text.append(f"- {item}") - text.append('\n') + text.append("\n") # add full change log text.append( - f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}') + f"**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}" + ) return text -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() - token = os.environ['GITHUB_API_TOKEN'] - headers = {'Authorization': token} + token = os.environ["GITHUB_API_TOKEN"] + headers = {"Authorization": token} # get previous release tag last_release_commit, last_version = get_latest_tag_commit(headers) last_release_commit_info = get_commit_info(last_release_commit, headers=headers) - last_release_date = last_release_commit_info['commit']['author']['date'] + last_release_date = last_release_commit_info["commit"]["author"]["date"] # get the commits since last release commit_info = get_all_commit_info(since=last_release_date, headers=headers) - commit_info = commit_info[:-1] # remove the release commit + commit_info = commit_info[:-1] # remove the release commit # collate into markdown release_info = collate_release_info(commit_info) markdown_text = generate_release_post_markdown(args.version, last_version, release_info) # write into a file - with open(args.out, 'w') as f: + with open(args.out, "w") as f: for line in markdown_text: f.write(line) diff --git a/.github/workflows/scripts/send_message_to_lark.py b/.github/workflows/scripts/send_message_to_lark.py index a113327a786e..bc005d93c3f5 100644 --- a/.github/workflows/scripts/send_message_to_lark.py +++ b/.github/workflows/scripts/send_message_to_lark.py @@ -5,8 +5,8 @@ def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('-m', '--message', type=str) - parser.add_argument('-u', '--url', type=str) + parser.add_argument("-m", "--message", type=str) + parser.add_argument("-u", "--url", type=str) return parser.parse_args() @@ -15,6 +15,6 @@ def send_message_to_lark(message, webhook_url): requests.post(webhook_url, json=data) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() send_message_to_lark(args.message, args.url) diff --git a/.isort.cfg b/.isort.cfg index 090aa28e39f3..4f881c8b3dda 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,3 +3,4 @@ line_length = 120 multi_line_output=3 include_trailing_comma = true ignore_comments = true +profile = black diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 725d266375ef..9871e1184462 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,31 @@ repos: + - repo: https://github.com/PyCQA/autoflake + rev: v2.2.1 + hooks: + - id: autoflake + name: autoflake (python) + args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports'] + - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort name: sort all imports (python) - - repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.32.0 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 23.9.1 hooks: - - id: yapf - name: yapf formatter - args: ['--style=.style.yapf', '--parallel', '--in-place'] + - id: black + name: black formatter + args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310'] - repo: https://github.com/pre-commit/mirrors-clang-format rev: v13.0.1 hooks: - id: clang-format name: clang formatter + types_or: [c++, c] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 diff --git a/.style.yapf b/.style.yapf deleted file mode 100644 index 05be0dc6a3a5..000000000000 --- a/.style.yapf +++ /dev/null @@ -1,5 +0,0 @@ -[style] -based_on_style = google -spaces_before_comment = 4 -split_before_logical_operator = true -column_limit = 120 diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py index 90471ed727b0..04f779821405 100644 --- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py +++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py @@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int: def preprocess_batch(samples) -> dict: input_ids = torch.stack(samples) attention_mask = torch.ones_like(input_ids, dtype=torch.long) - return {'input_ids': input_ids, 'attention_mask': attention_mask} + return {"input_ids": input_ids, "attention_mask": attention_mask} def print_rank_0(*args, **kwargs) -> None: @@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None: B = 1024**3 M = 1024**2 K = 1024 - outputs = '' + outputs = "" for name, numel in model_dict.items(): - outputs += f'{name}: ' + outputs += f"{name}: " if numel >= B: - outputs += f'{numel / B:.2f} B\n' + outputs += f"{numel / B:.2f} B\n" elif numel >= M: - outputs += f'{numel / M:.2f} M\n' + outputs += f"{numel / M:.2f} M\n" elif numel >= K: - outputs += f'{numel / K:.2f} K\n' + outputs += f"{numel / K:.2f} K\n" else: - outputs += f'{numel}\n' + outputs += f"{numel}\n" print_rank_0(outputs) def get_gpt_config(model_name: str) -> OPTConfig: model_map = { - '125m': OPTConfig.from_pretrained('facebook/opt-125m'), - '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), - '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), - '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'), - '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'), - '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), - '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), - '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'), - '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), - '13b': OPTConfig.from_pretrained('facebook/opt-13b'), + "125m": OPTConfig.from_pretrained("facebook/opt-125m"), + "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), + "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), + "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"), + "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"), + "3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), + "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), + "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"), + "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), + "13b": OPTConfig.from_pretrained("facebook/opt-13b"), } try: return model_map[model_name] @@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig: def main(args): - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) - elif args.strategy == 'colossalai_gemini_cpu': - strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5) - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2_cpu': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu') - elif args.strategy == 'colossalai_zero1': - strategy = LowLevelZeroStrategy(stage=1, placement_policy='cuda') - elif args.strategy == 'colossalai_zero1_cpu': - strategy = LowLevelZeroStrategy(stage=1, placement_policy='cpu') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + elif args.strategy == "colossalai_gemini_cpu": + strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5) + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") + elif args.strategy == "colossalai_zero2_cpu": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") + elif args.strategy == "colossalai_zero1": + strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda") + elif args.strategy == "colossalai_zero1_cpu": + strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') @@ -103,90 +103,106 @@ def main(args): if args.use_kernels: from coati.kernels import convert_to_xformer_model - actor, critic, initial_model, reward_model = map(convert_to_xformer_model, - (actor, critic, initial_model, reward_model)) + + actor, critic, initial_model, reward_model = map( + convert_to_xformer_model, (actor, critic, initial_model, reward_model) + ) actor_numel = get_model_numel(actor, strategy) critic_numel = get_model_numel(critic, strategy) initial_model_numel = get_model_numel(initial_model, strategy) reward_model_numel = get_model_numel(reward_model, strategy) - print_model_numel({ - 'Actor': actor_numel, - 'Critic': critic_numel, - 'Initial model': initial_model_numel, - 'Reward model': reward_model_numel - }) - performance_evaluator = PerformanceEvaluator(actor_numel, - critic_numel, - initial_model_numel, - reward_model_numel, - enable_grad_checkpoint=False, - ignore_episodes=1) - - if args.strategy.startswith('colossalai'): + print_model_numel( + { + "Actor": actor_numel, + "Critic": critic_numel, + "Initial model": initial_model_numel, + "Reward model": reward_model_numel, + } + ) + performance_evaluator = PerformanceEvaluator( + actor_numel, + critic_numel, + initial_model_numel, + reward_model_numel, + enable_grad_checkpoint=False, + ignore_episodes=1, + ) + + if args.strategy.startswith("colossalai"): actor_optim = HybridAdam(actor.parameters(), lr=5e-6) critic_optim = HybridAdam(critic.parameters(), lr=5e-6) else: actor_optim = Adam(actor.parameters(), lr=5e-6) critic_optim = Adam(critic.parameters(), lr=5e-6) - tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer.pad_token = tokenizer.eos_token (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device()) - dataloader = DataLoader(random_prompts, - batch_size=args.experience_batch_size, - shuffle=True, - collate_fn=preprocess_batch) - - trainer = PPOTrainer(strategy, - actor, - critic, - reward_model, - initial_model, - actor_optim, - critic_optim, - ptx_coef=0, - train_batch_size=args.train_batch_size, - offload_inference_models=args.offload_inference_models, - max_length=512, - do_sample=True, - temperature=1.0, - top_k=50, - use_cache=True, - pad_token_id=tokenizer.pad_token_id, - eos_token_id=tokenizer.eos_token_id, - callbacks=[performance_evaluator]) - - trainer.fit(prompt_dataloader=dataloader, - pretrain_dataloader=None, - num_episodes=args.num_episodes, - num_update_steps=args.num_update_steps, - num_collect_steps=args.num_collect_steps) - - print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') - - -if __name__ == '__main__': + dataloader = DataLoader( + random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch + ) + + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + ptx_coef=0, + train_batch_size=args.train_batch_size, + offload_inference_models=args.offload_inference_models, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + use_cache=True, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + callbacks=[performance_evaluator], + ) + + trainer.fit( + prompt_dataloader=dataloader, + pretrain_dataloader=None, + num_episodes=args.num_episodes, + num_update_steps=args.num_update_steps, + num_collect_steps=args.num_collect_steps, + ) + + print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB") + + +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', default='125m') - parser.add_argument('--critic_model', default='125m') - parser.add_argument('--strategy', - choices=[ - 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2', - 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu' - ], - default='ddp') - parser.add_argument('--num_episodes', type=int, default=3) - parser.add_argument('--num_collect_steps', type=int, default=8) - parser.add_argument('--num_update_steps', type=int, default=1) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0) - parser.add_argument('--cuda_mem_frac', type=float, default=1.0) - parser.add_argument('--offload_inference_models', action='store_true', default=False) - parser.add_argument('--use_kernels', action='store_true', default=False) + parser.add_argument("--model", default="125m") + parser.add_argument("--critic_model", default="125m") + parser.add_argument( + "--strategy", + choices=[ + "ddp", + "colossalai_gemini", + "colossalai_gemini_cpu", + "colossalai_zero2", + "colossalai_zero2_cpu", + "colossalai_zero1", + "colossalai_zero1_cpu", + ], + default="ddp", + ) + parser.add_argument("--num_episodes", type=int, default=3) + parser.add_argument("--num_collect_steps", type=int, default=8) + parser.add_argument("--num_update_steps", type=int, default=1) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--cuda_mem_frac", type=float, default=1.0) + parser.add_argument("--offload_inference_models", action="store_true", default=False) + parser.add_argument("--use_kernels", action="store_true", default=False) args = parser.parse_args() main(args) diff --git a/applications/Chat/benchmarks/ray/1mmt_dummy.py b/applications/Chat/benchmarks/ray/1mmt_dummy.py index 7fc990448805..98ace3869450 100644 --- a/applications/Chat/benchmarks/ray/1mmt_dummy.py +++ b/applications/Chat/benchmarks/ray/1mmt_dummy.py @@ -22,13 +22,13 @@ def get_free_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] def get_local_ip(): with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) + s.connect(("8.8.8.8", 80)) return s.getsockname()[0] @@ -36,22 +36,25 @@ def main(args): master_addr = str(get_local_ip()) # trainer_env_info trainer_port = str(get_free_port()) - env_info_trainers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_trainers), - 'master_port': trainer_port, - 'master_addr': master_addr - } for rank in range(args.num_trainers)] + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] # maker_env_info maker_port = str(get_free_port()) env_info_maker = { - 'local_rank': '0', - 'rank': '0', - 'world_size': '1', - 'master_port': maker_port, - 'master_addr': master_addr + "local_rank": "0", + "rank": "0", + "world_size": "1", + "master_port": maker_port, + "master_addr": master_addr, } # configure tokenizer @@ -63,21 +66,27 @@ def model_fn(): critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain) actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() - reward_model = get_reward_model_from_args(args.critic_model, - config=critic_cfg).requires_grad_(False).half().cuda() - if args.initial_model_quant_ckpt is not None and args.model == 'llama': + reward_model = ( + get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() + ) + if args.initial_model_quant_ckpt is not None and args.model == "llama": # quantize initial model with low_resource_init(), no_init_weights(): initial_model = get_actor_from_args(args.model, config=actor_cfg) - initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, - args.quant_group_size).cuda().requires_grad_(False) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) else: initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() return actor, critic, reward_model, initial_model # configure Experience Maker experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote( - detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)], + detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)], strategy_fn=partial(get_strategy_from_args, args.maker_strategy), model_fn=model_fn, env_info=env_info_maker, @@ -97,15 +106,18 @@ def model_fn(): def trainer_model_fn(): actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda() - critic = get_critic_from_args(args.critic_model, - config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda() + critic = ( + get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain)) + .half() + .cuda() + ) return actor, critic # configure Trainer trainer_refs = [ DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote( experience_maker_holder_name_list=[ - f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True) + f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True) ], strategy_fn=partial(get_strategy_from_args, args.trainer_strategy), model_fn=trainer_model_fn, @@ -114,7 +126,8 @@ def trainer_model_fn(): buffer_limit=16, eval_performance=True, debug=args.debug, - ) for i, env_info_trainer in enumerate(env_info_trainers) + ) + for i, env_info_trainer in enumerate(env_info_trainers) ] dataset_size = args.experience_batch_size * 4 @@ -122,7 +135,7 @@ def trainer_model_fn(): def data_gen_fn(): input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device()) attn_mask = torch.ones_like(input_ids) - return {'input_ids': input_ids, 'attention_mask': attn_mask} + return {"input_ids": input_ids, "attention_mask": attn_mask} def build_dataloader(size): dataset = [data_gen_fn() for _ in range(size)] @@ -138,8 +151,10 @@ def build_dataloader(size): wait_tasks = [] wait_tasks.append( - experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size), - num_steps=args.experience_steps)) + experience_holder_ref.workingloop.remote( + partial(build_dataloader, dataset_size), num_steps=args.experience_steps + ) + ) total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size) for trainer_ref in trainer_refs: @@ -148,31 +163,30 @@ def build_dataloader(size): ray.get(wait_tasks) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--num_trainers', type=int, default=1) - parser.add_argument('--trainer_strategy', - choices=[ - 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', - 'colossalai_zero2_cpu' - ], - default='ddp') - parser.add_argument('--maker_strategy', choices=['naive'], default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--critic_pretrain', type=str, default=None) - parser.add_argument('--experience_steps', type=int, default=4) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--train_epochs', type=int, default=1) - parser.add_argument('--update_steps', type=int, default=2) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) - parser.add_argument('--quant_bits', type=int, default=4) - parser.add_argument('--quant_group_size', type=int, default=128) - parser.add_argument('--debug', action='store_true') + parser.add_argument("--num_trainers", type=int, default=1) + parser.add_argument( + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") args = parser.parse_args() ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) main(args) diff --git a/applications/Chat/benchmarks/ray/mmmt_dummy.py b/applications/Chat/benchmarks/ray/mmmt_dummy.py index ca1df22070fc..f8860f2979ee 100644 --- a/applications/Chat/benchmarks/ray/mmmt_dummy.py +++ b/applications/Chat/benchmarks/ray/mmmt_dummy.py @@ -22,13 +22,13 @@ def get_free_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] def get_local_ip(): with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) + s.connect(("8.8.8.8", 80)) return s.getsockname()[0] @@ -36,23 +36,29 @@ def main(args): master_addr = str(get_local_ip()) # trainer_env_info trainer_port = str(get_free_port()) - env_info_trainers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_trainers), - 'master_port': trainer_port, - 'master_addr': master_addr - } for rank in range(args.num_trainers)] + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] # maker_env_info maker_port = str(get_free_port()) - env_info_makers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_makers), - 'master_port': maker_port, - 'master_addr': master_addr - } for rank in range(args.num_makers)] + env_info_makers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_makers), + "master_port": maker_port, + "master_addr": master_addr, + } + for rank in range(args.num_makers) + ] # configure tokenizer tokenizer = AutoTokenizer.from_pretrained(args.pretrain) @@ -63,14 +69,20 @@ def model_fn(): critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain) actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() - reward_model = get_reward_model_from_args(args.critic_model, - config=critic_cfg).requires_grad_(False).half().cuda() - if args.initial_model_quant_ckpt is not None and args.model == 'llama': + reward_model = ( + get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() + ) + if args.initial_model_quant_ckpt is not None and args.model == "llama": # quantize initial model with low_resource_init(), no_init_weights(): initial_model = get_actor_from_args(args.model, config=actor_cfg) - initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, - args.quant_group_size).cuda().requires_grad_(False) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) else: initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() return actor, critic, reward_model, initial_model @@ -79,7 +91,7 @@ def model_fn(): experience_holder_refs = [ ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote( detached_trainer_name_list=[ - f'trainer{x}' + f"trainer{x}" for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False) ], strategy_fn=partial(get_strategy_from_args, args.maker_strategy), @@ -103,8 +115,11 @@ def model_fn(): def trainer_model_fn(): actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda() - critic = get_critic_from_args(args.critic_model, - config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda() + critic = ( + get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain)) + .half() + .cuda() + ) return actor, critic # configure Trainer @@ -130,7 +145,7 @@ def trainer_model_fn(): def data_gen_fn(): input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device()) attn_mask = torch.ones_like(input_ids) - return {'input_ids': input_ids, 'attention_mask': attn_mask} + return {"input_ids": input_ids, "attention_mask": attn_mask} def build_dataloader(size): dataset = [data_gen_fn() for _ in range(size)] @@ -147,43 +162,48 @@ def build_dataloader(size): for experience_holder_ref in experience_holder_refs: wait_tasks.append( - experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size), - num_steps=args.experience_steps)) + experience_holder_ref.workingloop.remote( + partial(build_dataloader, dataset_size), num_steps=args.experience_steps + ) + ) - total_steps = args.experience_batch_size * args.experience_steps * \ - args.num_makers // (args.num_trainers * args.train_batch_size) + total_steps = ( + args.experience_batch_size + * args.experience_steps + * args.num_makers + // (args.num_trainers * args.train_batch_size) + ) for trainer_ref in trainer_refs: wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) ray.get(wait_tasks) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--num_makers', type=int, default=1) - parser.add_argument('--num_trainers', type=int, default=1) - parser.add_argument('--trainer_strategy', - choices=[ - 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', - 'colossalai_zero2_cpu' - ], - default='ddp') - parser.add_argument('--maker_strategy', choices=['naive'], default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--critic_pretrain', type=str, default=None) - parser.add_argument('--experience_steps', type=int, default=4) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--train_epochs', type=int, default=1) - parser.add_argument('--update_steps', type=int, default=2) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) - parser.add_argument('--quant_bits', type=int, default=4) - parser.add_argument('--quant_group_size', type=int, default=128) - parser.add_argument('--debug', action='store_true') + parser.add_argument("--num_makers", type=int, default=1) + parser.add_argument("--num_trainers", type=int, default=1) + parser.add_argument( + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") args = parser.parse_args() ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) main(args) diff --git a/applications/Chat/coati/dataset/__init__.py b/applications/Chat/coati/dataset/__init__.py index bd4e5460d11e..599b57609775 100644 --- a/applications/Chat/coati/dataset/__init__.py +++ b/applications/Chat/coati/dataset/__init__.py @@ -4,7 +4,10 @@ from .utils import is_rank_0 __all__ = [ - 'RmStaticDataset', 'HhRlhfDataset', - 'SFTDataset', 'SupervisedDataset', - 'PromptDataset', 'is_rank_0', + "RmStaticDataset", + "HhRlhfDataset", + "SFTDataset", + "SupervisedDataset", + "PromptDataset", + "is_rank_0", ] diff --git a/applications/Chat/coati/dataset/conversation.py b/applications/Chat/coati/dataset/conversation.py index 465fa867c7ab..f2180d96b0d3 100644 --- a/applications/Chat/coati/dataset/conversation.py +++ b/applications/Chat/coati/dataset/conversation.py @@ -49,7 +49,7 @@ def append_message(self, role, message): def to_gradio_chatbot(self): ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset:]): + for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: ret.append([msg, None]) else: @@ -57,12 +57,14 @@ def to_gradio_chatbot(self): return ret def copy(self): - return Conversation(system=self.system, - roles=self.roles, - messages=[[x, y] for x, y in self.messages], - offset=self.offset, - sep_style=self.sep_style, - sep=self.sep) + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + ) def dict(self): return { @@ -70,7 +72,7 @@ def dict(self): "roles": self.roles, "messages": self.messages, "offset": self.offset, - "sep": self.sep + "sep": self.sep, } diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py index 2c953fffa513..17120e6064b5 100644 --- a/applications/Chat/coati/dataset/prompt_dataset.py +++ b/applications/Chat/coati/dataset/prompt_dataset.py @@ -13,11 +13,13 @@ class PromptDataset(Dataset): """Dataset for supervised fine-tuning.""" - def __init__(self, - data_path: str, - tokenizer: transformers.PreTrainedTokenizer, - max_datasets_size: int = None, - max_length: int = 96): + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + max_datasets_size: int = None, + max_length: int = 96, + ): super(PromptDataset, self).__init__() self.keyed_prompt = defaultdict(list) self.logger = get_dist_logger() @@ -30,11 +32,9 @@ def __init__(self, list_data_dict = list_data_dict[:max_datasets_size] instructions = [data_dict["instruction"] for data_dict in list_data_dict] - tokens = tokenizer(instructions, - return_tensors='pt', - max_length=max_length, - padding='max_length', - truncation=True) + tokens = tokenizer( + instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True + ) for k, tensor in tokens.items(): self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind() diff --git a/applications/Chat/coati/dataset/reward_dataset.py b/applications/Chat/coati/dataset/reward_dataset.py index 3c4ec8b214bb..3afcd7b69238 100644 --- a/applications/Chat/coati/dataset/reward_dataset.py +++ b/applications/Chat/coati/dataset/reward_dataset.py @@ -20,44 +20,31 @@ class RmStaticDataset(Dataset): def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: super().__init__() - self.end_token = tokenizer.eos_token \ - if special_token is None else special_token - - chosen = [ - data["prompt"] + data["chosen"] + self.end_token - for data in tqdm(dataset, disable=not is_rank_0()) - ] - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.chosen = { - "input_ids": chosen_token["input_ids"], - "attention_mask": chosen_token["attention_mask"] - } - - reject = [ - data["prompt"] + data["rejected"] + self.end_token - for data in tqdm(dataset, disable=not is_rank_0()) - ] - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject = { - "input_ids": reject_token["input_ids"], - "attention_mask": reject_token["attention_mask"] - } + self.end_token = tokenizer.eos_token if special_token is None else special_token + + chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + chosen_token = tokenizer( + chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} + + reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + reject_token = tokenizer( + reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} def __len__(self): length = self.chosen["input_ids"].shape[0] return length def __getitem__(self, idx): - return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \ - self.reject["input_ids"][idx], self.reject["attention_mask"][idx] + return ( + self.chosen["input_ids"][idx], + self.chosen["attention_mask"][idx], + self.reject["input_ids"][idx], + self.reject["attention_mask"][idx], + ) # Anthropic/hh-rlhf @@ -74,41 +61,28 @@ class HhRlhfDataset(Dataset): def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: super().__init__() - self.end_token = tokenizer.eos_token \ - if special_token is None else special_token - - chosen = [ - data["chosen"] + self.end_token - for data in tqdm(dataset, disable=not is_rank_0()) - ] - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.chosen = { - "input_ids": chosen_token["input_ids"], - "attention_mask": chosen_token["attention_mask"] - } - - reject = [ - data["rejected"] + self.end_token - for data in tqdm(dataset, disable=not is_rank_0()) - ] - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject = { - "input_ids": reject_token["input_ids"], - "attention_mask": reject_token["attention_mask"] - } + self.end_token = tokenizer.eos_token if special_token is None else special_token + + chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + chosen_token = tokenizer( + chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} + + reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())] + reject_token = tokenizer( + reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} def __len__(self): length = self.chosen["input_ids"].shape[0] return length def __getitem__(self, idx): - return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \ - self.reject["input_ids"][idx], self.reject["attention_mask"][idx] + return ( + self.chosen["input_ids"][idx], + self.chosen["attention_mask"][idx], + self.reject["input_ids"][idx], + self.reject["attention_mask"][idx], + ) diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py index 2959d3fac81c..d6be09ca5cc9 100644 --- a/applications/Chat/coati/dataset/sft_dataset.py +++ b/applications/Chat/coati/dataset/sft_dataset.py @@ -16,10 +16,11 @@ from typing import Dict, Sequence, Tuple import torch +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from torch.utils.data import Dataset from tqdm import tqdm from transformers import PreTrainedTokenizer -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer + from colossalai.logging import get_dist_logger from .utils import is_rank_0, jload @@ -28,32 +29,33 @@ IGNORE_INDEX = -100 PROMPT_DICT = { - "prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"), - "prompt_no_input": ("Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:"), + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:" + ), } -def _preprocess(sources: Sequence[str], - targets: Sequence[str], - tokenizer: PreTrainedTokenizer, - max_length: int, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def _preprocess( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: PreTrainedTokenizer, + max_length: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Preprocess the data by tokenizing.""" sequences = [s + t for s, t in zip(sources, targets)] - sequences_token = tokenizer(sequences, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - sources_token = tokenizer(sources, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") + sequences_token = tokenizer( + sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + sources_token = tokenizer( + sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) labels = copy.deepcopy(sequences_token["input_ids"]) for i in range(labels.shape[0]): @@ -64,23 +66,24 @@ def _preprocess(sources: Sequence[str], labels[i][:source_len] = IGNORE_INDEX elif tokenizer.padding_side == "left": # |pad|prompt|completion|eos| - labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX + labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX else: raise RuntimeError() return sequences_token["input_ids"], labels, sequences_token["attention_mask"] -def _preprocess_chatglm(sources: Sequence[str], - targets: Sequence[str], - tokenizer: PreTrainedTokenizer, - max_length: int, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def _preprocess_chatglm( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: PreTrainedTokenizer, + max_length: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Preprocess the data by tokenizing. None for attention mask, ChatGLM will calculate attention mask according to input ids """ - + labels = [] input_ids = [] for source, target in zip(sources, targets): @@ -90,16 +93,16 @@ def _preprocess_chatglm(sources: Sequence[str], # truncate sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id] truncate_length = max(0, len(input_id) - max_length) - input_id = input_id[truncate_length: ] + input_id = input_id[truncate_length:] if truncate_length == len(source_id) + 1: - input_id = sp_token_list + input_id[1: ] + input_id = sp_token_list + input_id[1:] elif truncate_length > len(source_id) + 1: - input_id = sp_token_list + input_id[2: ] - + input_id = sp_token_list + input_id[2:] + context_length = input_id.index(tokenizer.bos_token_id) mask_position = context_length - 1 - label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:] - + label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :] + pad_len = max_length - len(input_id) input_id = input_id + [tokenizer.pad_token_id] * pad_len input_ids.append(input_id) @@ -117,25 +120,18 @@ class SFTDataset(Dataset): max_length: max length of input """ - def __init__(self, - dataset: Dict, - tokenizer: PreTrainedTokenizer, - max_length: int = 512 - ) -> None: + def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None: super().__init__() self.input_ids = [] sources = [data["prompt"] for data in dataset] - targets = [ - data["completion"] + tokenizer.eos_token - for data in tqdm(dataset, disable=not is_rank_0()) - ] + targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())] if isinstance(tokenizer, ChatGLMTokenizer): - self.input_ids, self.labels, self.attention_mask = \ - _preprocess_chatglm(sources, targets, tokenizer, max_length) + self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm( + sources, targets, tokenizer, max_length + ) else: - self.input_ids, self.labels, self.attention_mask = \ - _preprocess(sources, targets, tokenizer, max_length) + self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) def __len__(self): length = self.input_ids.shape[0] @@ -143,22 +139,17 @@ def __len__(self): def __getitem__(self, idx): if self.attention_mask is not None: - return dict(input_ids=self.input_ids[idx], - labels=self.labels[idx], - attention_mask=self.attention_mask[idx]) + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) else: - return dict(input_ids=self.input_ids[idx], - labels=self.labels[idx]) + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) class SupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" - def __init__(self, - data_path: str, - tokenizer: PreTrainedTokenizer, - max_datasets_size: int = None, - max_length: int = 512): + def __init__( + self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512 + ): super().__init__() logger.info("Loading data...") list_data_dict = jload(data_path) @@ -174,18 +165,15 @@ def __init__(self, prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example) for example in list_data_dict ] - targets = [ - example['output'] + tokenizer.eos_token - for example in list_data_dict - ] + targets = [example["output"] + tokenizer.eos_token for example in list_data_dict] logger.info("Tokenizing inputs... This may take some time...") if isinstance(tokenizer, ChatGLMTokenizer): - self.input_ids, self.labels, self.attention_mask = \ - _preprocess_chatglm(sources, targets, tokenizer, max_length) + self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm( + sources, targets, tokenizer, max_length + ) else: - self.input_ids, self.labels, self.attention_mask = \ - _preprocess(sources, targets, tokenizer, max_length) + self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length) def __len__(self): length = self.input_ids.shape[0] @@ -193,9 +181,6 @@ def __len__(self): def __getitem__(self, idx): if self.attention_mask is not None: - return dict(input_ids=self.input_ids[idx], - labels=self.labels[idx], - attention_mask=self.attention_mask[idx]) + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) else: - return dict(input_ids=self.input_ids[idx], - labels=self.labels[idx]) + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx]) diff --git a/applications/Chat/coati/experience_buffer/__init__.py b/applications/Chat/coati/experience_buffer/__init__.py index c0188dc4a471..f2a48d0a3b20 100644 --- a/applications/Chat/coati/experience_buffer/__init__.py +++ b/applications/Chat/coati/experience_buffer/__init__.py @@ -1,4 +1,4 @@ from .base import ExperienceBuffer from .naive import NaiveExperienceBuffer -__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer'] +__all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"] diff --git a/applications/Chat/coati/experience_buffer/base.py b/applications/Chat/coati/experience_buffer/base.py index 9ccdc935d506..7047785308f3 100644 --- a/applications/Chat/coati/experience_buffer/base.py +++ b/applications/Chat/coati/experience_buffer/base.py @@ -7,9 +7,9 @@ class ExperienceBuffer(ABC): """Experience buffer base class. It stores experience. - Args: - sample_batch_size (int): Batch size when sampling. - limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. """ def __init__(self, sample_batch_size: int, limit: int = 0) -> None: diff --git a/applications/Chat/coati/experience_buffer/naive.py b/applications/Chat/coati/experience_buffer/naive.py index bd5213b38993..acc0fbe88ab4 100644 --- a/applications/Chat/coati/experience_buffer/naive.py +++ b/applications/Chat/coati/experience_buffer/naive.py @@ -11,23 +11,23 @@ class NaiveExperienceBuffer(ExperienceBuffer): """Naive experience buffer class. It stores experience. - Args: - sample_batch_size (int): Batch size when sampling. - limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. - cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True. + Args: + sample_batch_size (int): Batch size when sampling. + limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0. + cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True. """ def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None: super().__init__(sample_batch_size, limit) self.cpu_offload = cpu_offload - self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}') + self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}") # TODO(ver217): add prefetch self.items: List[BufferItem] = [] @torch.no_grad() def append(self, experience: Experience) -> None: if self.cpu_offload: - experience.to_device(torch.device('cpu')) + experience.to_device(torch.device("cpu")) items = split_experience_batch(experience) self.items.extend(items) if self.limit > 0: diff --git a/applications/Chat/coati/experience_buffer/utils.py b/applications/Chat/coati/experience_buffer/utils.py index c2a34212e2f4..baedbebd184f 100644 --- a/applications/Chat/coati/experience_buffer/utils.py +++ b/applications/Chat/coati/experience_buffer/utils.py @@ -21,6 +21,7 @@ class BufferItem: "A" is the number of actions. """ + sequences: torch.Tensor action_log_probs: torch.Tensor values: torch.Tensor @@ -33,8 +34,7 @@ class BufferItem: def split_experience_batch(experience: Experience) -> List[BufferItem]: batch_size = experience.sequences.size(0) batch_kwargs = [{} for _ in range(batch_size)] - keys = ('sequences', 'action_log_probs', 'values', - 'reward', 'advantages', 'attention_mask', 'action_mask') + keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask") for key in keys: value = getattr(experience, key) if isinstance(value, torch.Tensor): @@ -49,22 +49,21 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]: return items -def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor: - assert side in ('left', 'right') +def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor: + assert side in ("left", "right") max_len = max(seq.size(0) for seq in sequences) padded_sequences = [] for seq in sequences: pad_len = max_len - seq.size(0) - padding = (pad_len, 0) if side == 'left' else (0, pad_len) + padding = (pad_len, 0) if side == "left" else (0, pad_len) padded_sequences.append(F.pad(seq, padding)) return torch.stack(padded_sequences, dim=0) def make_experience_batch(items: List[BufferItem]) -> Experience: kwargs = {} - to_pad_keys = set(('action_log_probs', 'action_mask')) - keys = ('sequences', 'action_log_probs', 'values', - 'reward', 'advantages', 'attention_mask', 'action_mask') + to_pad_keys = set(("action_log_probs", "action_mask")) + keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask") for key in keys: vals = [getattr(item, key) for item in items] if key in to_pad_keys: diff --git a/applications/Chat/coati/experience_maker/__init__.py b/applications/Chat/coati/experience_maker/__init__.py index 39ca7576b227..06452292e77c 100644 --- a/applications/Chat/coati/experience_maker/__init__.py +++ b/applications/Chat/coati/experience_maker/__init__.py @@ -1,4 +1,4 @@ from .base import Experience, ExperienceMaker from .naive import NaiveExperienceMaker -__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker'] +__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"] diff --git a/applications/Chat/coati/experience_maker/base.py b/applications/Chat/coati/experience_maker/base.py index b4646f282f0c..727f0a4a52e8 100644 --- a/applications/Chat/coati/experience_maker/base.py +++ b/applications/Chat/coati/experience_maker/base.py @@ -24,6 +24,7 @@ class Experience: "A" is the number of actions. """ + sequences: torch.Tensor action_log_probs: torch.Tensor values: torch.Tensor @@ -58,13 +59,9 @@ def pin_memory(self): class ExperienceMaker(ABC): - - def __init__(self, - actor: Actor, - critic: nn.Module, - reward_model: nn.Module, - initial_model: Actor, - kl_coef: float = 0.1) -> None: + def __init__( + self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1 + ) -> None: super().__init__() self.actor = actor self.critic = critic diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py index 496f8ab445fc..30dfd8e0b9bc 100644 --- a/applications/Chat/coati/experience_maker/naive.py +++ b/applications/Chat/coati/experience_maker/naive.py @@ -23,22 +23,21 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experie # calculate auxiliary tensors attention_mask = None - pad_token_id = generate_kwargs.get('pad_token_id', None) + pad_token_id = generate_kwargs.get("pad_token_id", None) if pad_token_id is not None: - attention_mask = sequences.not_equal(pad_token_id)\ - .to(dtype=torch.long, device=sequences.device) + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) input_len = input_ids.size(1) - eos_token_id = generate_kwargs.get('eos_token_id', None) + eos_token_id = generate_kwargs.get("eos_token_id", None) if eos_token_id is None: action_mask = torch.ones_like(sequences, dtype=torch.bool) else: # left padding may be applied, only mask action action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 - action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input action_mask[:, :input_len] = False action_mask = action_mask[:, 1:] - action_mask = action_mask[:, -(sequences.size(1) - input_len):] + action_mask = action_mask[:, -(sequences.size(1) - input_len) :] num_actions = action_mask.size(1) actor_output = self.actor(sequences, attention_mask) diff --git a/applications/Chat/coati/kernels/__init__.py b/applications/Chat/coati/kernels/__init__.py index 230eedf7ecba..96d40c7c4709 100644 --- a/applications/Chat/coati/kernels/__init__.py +++ b/applications/Chat/coati/kernels/__init__.py @@ -1,6 +1,6 @@ from .wrapper import convert_to_xformer_model, recover_from_xformer_model __all__ = [ - 'convert_to_xformer_model', - 'recover_from_xformer_model', + "convert_to_xformer_model", + "recover_from_xformer_model", ] diff --git a/applications/Chat/coati/kernels/opt_attn.py b/applications/Chat/coati/kernels/opt_attn.py index e99f9c2247d1..d1eb139187f3 100644 --- a/applications/Chat/coati/kernels/opt_attn.py +++ b/applications/Chat/coati/kernels/opt_attn.py @@ -21,11 +21,12 @@ def forward( output_attentions: bool = False, ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]: if not self.training: - return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, - output_attentions) + return super().forward( + hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions + ) """Input shape: Batch x Time x Channel""" - assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask' - assert not output_attentions, 'Xformers attention does not support output_attentions' + assert layer_head_mask is None, "Xformers attention does not support layer_head_mask" + assert not output_attentions, "Xformers attention does not support output_attentions" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder @@ -69,12 +70,14 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = xops.memory_efficient_attention(query_states, - key_states, - value_states, - attn_bias=xops.LowerTriangularMask(), - p=self.dropout if self.training else 0.0, - scale=self.scaling) + attn_output = xops.memory_efficient_attention( + query_states, + key_states, + value_states, + attn_bias=xops.LowerTriangularMask(), + p=self.dropout if self.training else 0.0, + scale=self.scaling, + ) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. diff --git a/applications/Chat/coati/models/__init__.py b/applications/Chat/coati/models/__init__.py index 0a296a863756..ad4a525b4af2 100644 --- a/applications/Chat/coati/models/__init__.py +++ b/applications/Chat/coati/models/__init__.py @@ -3,6 +3,13 @@ from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss __all__ = [ - 'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss', - 'LoRAModule', 'convert_to_lora_module' + "Actor", + "Critic", + "RewardModel", + "PolicyLoss", + "ValueLoss", + "LogSigLoss", + "LogExpLoss", + "LoRAModule", + "convert_to_lora_module", ] diff --git a/applications/Chat/coati/models/base/__init__.py b/applications/Chat/coati/models/base/__init__.py index c5f748a0c85a..5c9905bb2224 100644 --- a/applications/Chat/coati/models/base/__init__.py +++ b/applications/Chat/coati/models/base/__init__.py @@ -9,7 +9,7 @@ def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module: """Get the base model of our wrapper classes. - For Actor, Critic and RewardModel, return ``model.model``, + For Actor, Critic and RewardModel, return ``model.model``, it's usually a ``transformers.PreTrainedModel``. Args: @@ -18,9 +18,10 @@ def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module: Returns: nn.Module: the base model """ - assert isinstance(model, (Actor, Critic, RewardModel)), \ - f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.' + assert isinstance( + model, (Actor, Critic, RewardModel) + ), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first." return model.model -__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model'] +__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"] diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py index 6842f81d9b87..979f9318be50 100644 --- a/applications/Chat/coati/models/base/actor.py +++ b/applications/Chat/coati/models/base/actor.py @@ -16,18 +16,17 @@ class Actor(LoRAModule): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: + def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None: super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.convert_to_lora() def forward( - self, - input_ids: torch.LongTensor, - attention_mask: Optional[torch.Tensor] = None, - **model_kwargs, # HACK: `generate` method may pass more kwargs + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + **model_kwargs, # HACK: `generate` method may pass more kwargs ) -> torch.Tensor: - """Returns model output. - """ + """Returns model output.""" output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs) return output diff --git a/applications/Chat/coati/models/base/critic.py b/applications/Chat/coati/models/base/critic.py index e68a743a7762..54ab7fa47d48 100644 --- a/applications/Chat/coati/models/base/critic.py +++ b/applications/Chat/coati/models/base/critic.py @@ -23,22 +23,23 @@ def __init__( model: nn.Module, value_head: nn.Module, lora_rank: int = 0, - lora_train_bias: str = 'none', + lora_train_bias: str = "none", use_action_mask: bool = False, ) -> None: - super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.value_head = value_head self.use_action_mask = use_action_mask self.convert_to_lora() - def forward(self, - sequences: torch.LongTensor, - action_mask: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + sequences: torch.LongTensor, + action_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: outputs = self.model(sequences, attention_mask=attention_mask) - last_hidden_states = outputs['last_hidden_state'] + last_hidden_states = outputs["last_hidden_state"] values = self.value_head(last_hidden_states).squeeze(-1) diff --git a/applications/Chat/coati/models/base/reward_model.py b/applications/Chat/coati/models/base/reward_model.py index ce8c0a1d3568..1a70c6cc12bb 100644 --- a/applications/Chat/coati/models/base/reward_model.py +++ b/applications/Chat/coati/models/base/reward_model.py @@ -17,11 +17,13 @@ class RewardModel(LoRAModule): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - model: nn.Module, - value_head: Optional[nn.Module] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + model: nn.Module, + value_head: Optional[nn.Module] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) self.model = model self.convert_to_lora() @@ -35,7 +37,7 @@ def __init__(self, def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: outputs = self.model(sequences, attention_mask=attention_mask) - last_hidden_states = outputs['last_hidden_state'] + last_hidden_states = outputs["last_hidden_state"] values = self.value_head(last_hidden_states)[:, :-1] - value = values.mean(dim=1).squeeze(1) # ensure shape is (B) + value = values.mean(dim=1).squeeze(1) # ensure shape is (B) return value diff --git a/applications/Chat/coati/models/bloom/__init__.py b/applications/Chat/coati/models/bloom/__init__.py index d0e7f7b1ef94..7af199a67d3b 100644 --- a/applications/Chat/coati/models/bloom/__init__.py +++ b/applications/Chat/coati/models/bloom/__init__.py @@ -2,4 +2,4 @@ from .bloom_critic import BLOOMCritic from .bloom_rm import BLOOMRM -__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM'] +__all__ = ["BLOOMActor", "BLOOMCritic", "BLOOMRM"] diff --git a/applications/Chat/coati/models/bloom/bloom_actor.py b/applications/Chat/coati/models/bloom/bloom_actor.py index d7577f096493..73855a2245e7 100644 --- a/applications/Chat/coati/models/bloom/bloom_actor.py +++ b/applications/Chat/coati/models/bloom/bloom_actor.py @@ -1,7 +1,6 @@ from typing import Optional -import torch -from transformers import BloomConfig, BloomForCausalLM, BloomModel +from transformers import BloomConfig, BloomForCausalLM from ..base import Actor @@ -18,12 +17,14 @@ class BLOOMActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = BloomForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/bloom/bloom_critic.py b/applications/Chat/coati/models/bloom/bloom_critic.py index a3716ca94138..b2d838f7ffc5 100644 --- a/applications/Chat/coati/models/bloom/bloom_critic.py +++ b/applications/Chat/coati/models/bloom/bloom_critic.py @@ -1,8 +1,7 @@ from typing import Optional -import torch import torch.nn as nn -from transformers import BloomConfig, BloomForCausalLM, BloomModel +from transformers import BloomConfig, BloomModel from ..base import Critic @@ -18,12 +17,14 @@ class BLOOMCritic(Critic): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = BloomModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/bloom/bloom_rm.py b/applications/Chat/coati/models/bloom/bloom_rm.py index e6ca9b1d4851..c09457ddc8c7 100644 --- a/applications/Chat/coati/models/bloom/bloom_rm.py +++ b/applications/Chat/coati/models/bloom/bloom_rm.py @@ -1,7 +1,7 @@ from typing import Optional import torch.nn as nn -from transformers import BloomConfig, BloomForCausalLM, BloomModel +from transformers import BloomConfig, BloomModel from ..base import RewardModel @@ -17,11 +17,13 @@ class BLOOMRM(RewardModel): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = BloomModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/chatglm/__init__.py b/applications/Chat/coati/models/chatglm/__init__.py index 373f19553fdc..5956f5a8e91b 100644 --- a/applications/Chat/coati/models/chatglm/__init__.py +++ b/applications/Chat/coati/models/chatglm/__init__.py @@ -1,3 +1,3 @@ from .chatglm_actor import ChatGLMActor -__all__ = ['ChatGLMActor'] \ No newline at end of file +__all__ = ["ChatGLMActor"] diff --git a/applications/Chat/coati/models/chatglm/chatglm_actor.py b/applications/Chat/coati/models/chatglm/chatglm_actor.py index c35d994e9319..00a61561ee47 100644 --- a/applications/Chat/coati/models/chatglm/chatglm_actor.py +++ b/applications/Chat/coati/models/chatglm/chatglm_actor.py @@ -1,11 +1,9 @@ from typing import Optional -import torch +from ..base import Actor from .configuration_chatglm import ChatGLMConfig from .modeling_chatglm import ChatGLMForConditionalGeneration -from ..base import Actor - class ChatGLMActor(Actor): """ @@ -19,10 +17,9 @@ class ChatGLMActor(Actor): do not support lora for now. """ - def __init__(self, - pretrained: str = None, - config: Optional[ChatGLMConfig] = None, - checkpoint: bool = False) -> None: + def __init__( + self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False + ) -> None: if pretrained is not None: model = ChatGLMForConditionalGeneration.from_pretrained(pretrained) elif config is not None: @@ -31,4 +28,4 @@ def __init__(self, model = ChatGLMForConditionalGeneration(ChatGLMConfig()) if checkpoint: model.gradient_checkpointing_enable() - super().__init__(model, lora_rank=0, lora_train_bias='none') + super().__init__(model, lora_rank=0, lora_train_bias="none") diff --git a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py index f7717f7e68b6..221ef044b470 100644 --- a/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py +++ b/applications/Chat/coati/models/chatglm/chatglm_tokenizer.py @@ -2,15 +2,14 @@ This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py """ """Tokenization classes for ChatGLM.""" -from typing import List, Optional, Union import os +from typing import Dict, List, Optional, Union -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.utils import logging, PaddingStrategy -from transformers.tokenization_utils_base import EncodedInput, BatchEncoding -from typing import Dict -import sentencepiece as spm import numpy as np +import sentencepiece as spm +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_base import BatchEncoding, EncodedInput +from transformers.utils import PaddingStrategy, logging logger = logging.get_logger(__name__) @@ -52,11 +51,11 @@ def __len__(self): class SPTokenizer: def __init__( - self, - vocab_file, - num_image_tokens=20000, - max_blank_length=80, - byte_fallback=True, + self, + vocab_file, + num_image_tokens=20000, + max_blank_length=80, + byte_fallback=True, ): assert vocab_file is not None self.vocab_file = vocab_file @@ -100,9 +99,7 @@ def _preprocess(self, text: str, linebreak=True, whitespaces=True): text = self._encode_whitespaces(text, max_len=self.max_blank_length) return text - def encode( - self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True - ) -> List[int]: + def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]: """ @param text: Text to encode. @param linebreak: Whether to encode newline (\n) in text. @@ -136,9 +133,7 @@ def decode_tokens(self, tokens: List[str]) -> str: text = self.postprocess(text) return text - def tokenize( - self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True - ) -> List[str]: + def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]: """ @param text: Text to encode. @param linebreak: Whether to encode newline (\n) in text. @@ -181,20 +176,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask", "position_ids"] def __init__( - self, - vocab_file, - do_lower_case=False, - remove_space=False, - bos_token='', - eos_token='', - end_token='', - mask_token='[MASK]', - gmask_token='[gMASK]', - padding_side="left", - pad_token="", - unk_token="", - num_image_tokens=20000, - **kwargs + self, + vocab_file, + do_lower_case=False, + remove_space=False, + bos_token="", + eos_token="", + end_token="", + mask_token="[MASK]", + gmask_token="[gMASK]", + padding_side="left", + pad_token="", + unk_token="", + num_image_tokens=20000, + **kwargs, ) -> None: super().__init__( do_lower_case=do_lower_case, @@ -208,7 +203,7 @@ def __init__( pad_token=pad_token, unk_token=unk_token, num_image_tokens=num_image_tokens, - **kwargs + **kwargs, ) self.do_lower_case = do_lower_case @@ -243,11 +238,11 @@ def end_token_id(self) -> Optional[int]: @property def vocab_size(self): - """ Returns vocab size """ + """Returns vocab size""" return self.sp_tokenizer.num_tokens def get_vocab(self): - """ Returns vocab as a dict """ + """Returns vocab as a dict""" vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} vocab.update(self.added_tokens_encoder) return vocab @@ -264,7 +259,7 @@ def preprocess_text(self, inputs): return outputs def _tokenize(self, text, **kwargs): - """ Returns a tokenized string. """ + """Returns a tokenized string.""" text = self.preprocess_text(text) seq = self.sp_tokenizer.tokenize(text) @@ -274,11 +269,7 @@ def _tokenize(self, text, **kwargs): def convert_tokens_to_string(self, tokens: List[str]) -> str: return self.sp_tokenizer.decode_tokens(tokens) - def _decode( - self, - token_ids: Union[int, List[int]], - **kwargs - ) -> str: + def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str: if isinstance(token_ids, int): token_ids = [token_ids] if len(token_ids) == 0: @@ -288,7 +279,7 @@ def _decode( return super()._decode(token_ids, **kwargs) def _convert_token_to_id(self, token): - """ Converts a token (str) in an id using the vocab. """ + """Converts a token (str) in an id using the vocab.""" return self.sp_tokenizer[token] def _convert_id_to_token(self, index): @@ -309,13 +300,11 @@ def save_vocabulary(self, save_directory, filename_prefix=None): `Tuple(str)`: Paths to the files saved. """ if os.path.isdir(save_directory): - vocab_file = os.path.join( - save_directory, self.vocab_files_names["vocab_file"] - ) + vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"]) else: vocab_file = save_directory - with open(self.vocab_file, 'rb') as fin: + with open(self.vocab_file, "rb") as fin: proto_str = fin.read() with open(vocab_file, "wb") as writer: @@ -324,7 +313,7 @@ def save_vocabulary(self, save_directory, filename_prefix=None): return (vocab_file,) def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and @@ -343,19 +332,19 @@ def build_inputs_with_special_tokens( `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. """ gmask_id = self.sp_tokenizer[self.gmask_token] - eos_id = self.sp_tokenizer[self.eos_token] + self.sp_tokenizer[self.eos_token] token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]] if token_ids_1 is not None: token_ids_0 = token_ids_0 + token_ids_1 return token_ids_0 def _pad( - self, - encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], - max_length: Optional[int] = None, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - pad_to_multiple_of: Optional[int] = None, - return_attention_mask: Optional[bool] = None, + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, ) -> dict: """ Pad encoded inputs (on left/right and up to predefined length or max length in the batch) @@ -421,17 +410,23 @@ def _pad( mask_position = required_input.index(mask_token) position_ids[context_length:] = mask_position block_position_ids = np.concatenate( - [np.zeros(context_length, dtype=np.int64), - np.arange(1, seq_length - context_length + 1, dtype=np.int64)]) + [ + np.zeros(context_length, dtype=np.int64), + np.arange(1, seq_length - context_length + 1, dtype=np.int64), + ] + ) encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) if needs_to_be_padded: difference = max_length - len(required_input) if "attention_mask" in encoded_inputs: - encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"], - pad_width=[(0, 0), (difference, 0), (difference, 0)], - mode='constant', constant_values=True) + encoded_inputs["attention_mask"] = np.pad( + encoded_inputs["attention_mask"], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode="constant", + constant_values=True, + ) if "token_type_ids" in encoded_inputs: encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ "token_type_ids" @@ -439,8 +434,9 @@ def _pad( if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] if "position_ids" in encoded_inputs: - encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"], - pad_width=[(0, 0), (difference, 0)]) + encoded_inputs["position_ids"] = np.pad( + encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)] + ) encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input - return encoded_inputs \ No newline at end of file + return encoded_inputs diff --git a/applications/Chat/coati/models/chatglm/configuration_chatglm.py b/applications/Chat/coati/models/chatglm/configuration_chatglm.py index d0e3f6cc63d7..a6d2ccd18715 100644 --- a/applications/Chat/coati/models/chatglm/configuration_chatglm.py +++ b/applications/Chat/coati/models/chatglm/configuration_chatglm.py @@ -56,30 +56,29 @@ class ChatGLMConfig(PretrainedConfig): >>> # Accessing the model configuration >>> configuration = model.config - ``` -""" + ```""" model_type = "chatglm" def __init__( - self, - vocab_size=130528, - hidden_size=4096, - num_layers=28, - num_attention_heads=32, - layernorm_epsilon=1e-5, - use_cache=True, - bos_token_id=130004, - eos_token_id=130005, - mask_token_id=130000, - gmask_token_id=130001, - pad_token_id=3, - max_sequence_length=2048, - inner_hidden_size=16384, - position_encoding_2d=True, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs + self, + vocab_size=130528, + hidden_size=4096, + num_layers=28, + num_attention_heads=32, + layernorm_epsilon=1e-5, + use_cache=True, + bos_token_id=130004, + eos_token_id=130005, + mask_token_id=130000, + gmask_token_id=130001, + pad_token_id=3, + max_sequence_length=2048, + inner_hidden_size=16384, + position_encoding_2d=True, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs, ): self.num_layers = num_layers self.vocab_size = vocab_size @@ -99,9 +98,4 @@ def __init__( self.pre_seq_len = pre_seq_len self.prefix_projection = prefix_projection - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs - ) \ No newline at end of file + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/applications/Chat/coati/models/chatglm/modeling_chatglm.py b/applications/Chat/coati/models/chatglm/modeling_chatglm.py index 77e7d0d8ea09..d1d15c68ffd8 100644 --- a/applications/Chat/coati/models/chatglm/modeling_chatglm.py +++ b/applications/Chat/coati/models/chatglm/modeling_chatglm.py @@ -4,41 +4,40 @@ """ PyTorch ChatGLM model. """ -import math import copy +import math import os -import warnings import re import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint import torch.nn.functional as F +import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any - -from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, -) +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList from transformers.modeling_outputs import ( BaseModelOutputWithPast, - CausalLMOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) from .configuration_chatglm import ChatGLMConfig # flags required to enable jit fusion kernels -if sys.platform != 'darwin': +if sys.platform != "darwin": torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) torch._C._jit_override_can_fuse_on_cpu(True) @@ -93,8 +92,8 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name ): logger.info(f"Skipping {'/'.join(name)}") continue @@ -127,7 +126,7 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): array = np.transpose(array) try: assert ( - pointer.shape == array.shape + pointer.shape == array.shape ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" except AssertionError as e: e.args += (pointer.shape, array.shape) @@ -153,7 +152,7 @@ def __init__(self, config): self.trans = torch.nn.Sequential( torch.nn.Linear(config.hidden_size, config.hidden_size), torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) + torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2), ) else: self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) @@ -170,8 +169,7 @@ def forward(self, prefix: torch.Tensor): @torch.jit.script def gelu_impl(x): """OpenAI's gelu implementation.""" - return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * - (1.0 + 0.044715 * x * x))) + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) def gelu(x): @@ -181,21 +179,22 @@ def gelu(x): class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, base=10000, precision=torch.half, learnable=False): super().__init__() - inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = inv_freq.half() self.learnable = learnable if learnable: self.inv_freq = torch.nn.Parameter(inv_freq) self.max_seq_len_cached = None else: - self.register_buffer('inv_freq', inv_freq) + self.register_buffer("inv_freq", inv_freq) self.max_seq_len_cached = None self.cos_cached = None self.sin_cached = None self.precision = precision - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): pass def forward(self, x, seq_dim=1, seq_len=None): @@ -204,7 +203,7 @@ def forward(self, x, seq_dim=1, seq_len=None): if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): self.max_seq_len_cached = None if self.learnable else seq_len t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1).to(x.device) if self.precision == torch.bfloat16: @@ -230,30 +229,31 @@ def _apply(self, fn): def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions @torch.jit.script def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] - cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ - F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding( + position_id, sin.squeeze(1) + ).unsqueeze(2) q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) return q, k def attention_fn( - self, - query_layer, - key_layer, - value_layer, - attention_mask, - hidden_size_per_partition, - layer_id, - layer_past=None, - scaling_attention_score=True, - use_cache=False, + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, + layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, ): if layer_past is not None: past_key, past_value = layer_past[0], layer_past[1] @@ -285,7 +285,9 @@ def attention_fn( key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) matmul_result = torch.zeros( - 1, 1, 1, + 1, + 1, + 1, dtype=query_layer.dtype, device=query_layer.device, ) @@ -355,9 +357,17 @@ def default_init(cls, *args, **kwargs): class SelfAttention(torch.nn.Module): - def __init__(self, hidden_size, num_attention_heads, - layer_id, hidden_size_per_attention_head=None, bias=True, - params_dtype=torch.float, position_encoding_2d=True, empty_init=True): + def __init__( + self, + hidden_size, + num_attention_heads, + layer_id, + hidden_size_per_attention_head=None, + bias=True, + params_dtype=torch.float, + position_encoding_2d=True, + empty_init=True, + ): if empty_init: init_method = skip_init else: @@ -410,8 +420,7 @@ def attention_mask_func(attention_scores, attention_mask): attention_scores.masked_fill_(attention_mask, -10000.0) return attention_scores - def split_tensor_along_last_dim(self, tensor, num_partitions, - contiguous_split_chunks=False): + def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False): """Split a tensor along its last dimension. Arguments: tensor: input tensor. @@ -431,14 +440,14 @@ def split_tensor_along_last_dim(self, tensor, num_partitions, return tensor_list def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, ): """ hidden_states: [seq_len, batch, hidden_size] @@ -462,8 +471,10 @@ def forward( q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) - position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ - position_ids[:, 1, :].transpose(0, 1).contiguous() + position_ids, block_position_ids = ( + position_ids[:, 0, :].transpose(0, 1).contiguous(), + position_ids[:, 1, :].transpose(0, 1).contiguous(), + ) q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) @@ -484,7 +495,7 @@ def forward( hidden_size_per_partition=self.hidden_size_per_partition, layer_id=layer_id, layer_past=layer_past, - use_cache=use_cache + use_cache=use_cache, ) output = self.dense(context_layer) @@ -509,8 +520,16 @@ def forward(self, x): class GLU(torch.nn.Module): - def __init__(self, hidden_size, inner_hidden_size=None, - layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): + def __init__( + self, + hidden_size, + inner_hidden_size=None, + layer_id=None, + bias=True, + activation_func=gelu, + params_dtype=torch.float, + empty_init=True, + ): super(GLU, self).__init__() if empty_init: init_method = skip_init @@ -557,19 +576,19 @@ def forward(self, hidden_states): class GLMBlock(torch.nn.Module): def __init__( - self, - hidden_size, - num_attention_heads, - layernorm_epsilon, - layer_id, - inner_hidden_size=None, - hidden_size_per_attention_head=None, - layernorm=LayerNorm, - use_bias=True, - params_dtype=torch.float, - num_layers=28, - position_encoding_2d=True, - empty_init=True + self, + hidden_size, + num_attention_heads, + layernorm_epsilon, + layer_id, + inner_hidden_size=None, + hidden_size_per_attention_head=None, + layernorm=LayerNorm, + use_bias=True, + params_dtype=torch.float, + num_layers=28, + position_encoding_2d=True, + empty_init=True, ): super(GLMBlock, self).__init__() # Set output layer initialization if not provided. @@ -590,7 +609,7 @@ def __init__( bias=use_bias, params_dtype=params_dtype, position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init + empty_init=empty_init, ) # Layernorm on the input data. @@ -605,18 +624,18 @@ def __init__( bias=use_bias, layer_id=layer_id, params_dtype=params_dtype, - empty_init=empty_init + empty_init=empty_init, ) def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + output_attentions: bool = False, ): """ hidden_states: [seq_len, batch, hidden_size] @@ -635,7 +654,7 @@ def forward( layer_id=layer_id, layer_past=layer_past, use_cache=use_cache, - output_attentions=output_attentions + output_attentions=output_attentions, ) attention_output = attention_outputs[0] @@ -702,10 +721,15 @@ def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) for i, context_length in enumerate(context_lengths): position_ids[i, context_length:] = mask_positions[i] - block_position_ids = [torch.cat(( - torch.zeros(context_length, dtype=torch.long, device=device), - torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 - )) for context_length in context_lengths] + block_position_ids = [ + torch.cat( + ( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1, + ) + ) + for context_length in context_lengths + ] block_position_ids = torch.stack(block_position_ids, dim=0) position_ids = torch.stack((position_ids, block_position_ids), dim=1) else: @@ -823,9 +847,7 @@ def __init__(self, config: ChatGLMConfig, empty_init=True): self.prefix_projection = config.prefix_projection self.word_embeddings = init_method( - torch.nn.Embedding, - num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, - dtype=self.params_dtype + torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype ) self.gradient_checkpointing = False @@ -841,12 +863,10 @@ def get_layer(layer_id): use_bias=True, params_dtype=self.params_dtype, position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init + empty_init=empty_init, ) - self.layers = torch.nn.ModuleList( - [get_layer(layer_id) for layer_id in range(self.num_layers)] - ) + self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)]) # Final layer norm before output. self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) @@ -876,7 +896,7 @@ def get_prompt(self, batch_size, device, dtype=torch.half): self.pre_seq_len, self.num_layers * 2, self.num_attention_heads, - self.hidden_size // self.num_attention_heads + self.hidden_size // self.num_attention_heads, ) # seq_len, b, nh, hidden_size past_key_values = self.dropout(past_key_values) @@ -891,18 +911,17 @@ def get_prompt(self, batch_size, device, dtype=torch.half): config_class=_CONFIG_FOR_DOC, ) def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -931,17 +950,14 @@ def forward( if past_key_values is None: if self.pre_seq_len is not None: - past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, - dtype=inputs_embeds.dtype) + past_key_values = self.get_prompt( + batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype + ) else: past_key_values = tuple([None] * len(self.layers)) if attention_mask is None: - attention_mask = self.get_masks( - input_ids, - device=input_ids.device - ) - + attention_mask = self.get_masks(input_ids, device=input_ids.device) if position_ids is None: MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id @@ -955,15 +971,13 @@ def forward( use_gmasks.append(use_gmask) position_ids = self.get_position_ids( - input_ids, - mask_positions=mask_positions, - device=input_ids.device, - use_gmasks=use_gmasks + input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks ) if self.pre_seq_len is not None and attention_mask is not None: prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( - attention_mask.device) + attention_mask.device + ) prefix_attention_mask = (prefix_attention_mask < 0.5).bool() attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) @@ -980,7 +994,6 @@ def forward( attention_mask = attention_mask.to(hidden_states.device) for i, layer in enumerate(self.layers): - if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_past = past_key_values[i] @@ -994,7 +1007,7 @@ def forward( torch.tensor(i), layer_past, use_cache, - output_attentions + output_attentions, ) else: layer_ret = layer( @@ -1004,7 +1017,7 @@ def forward( layer_id=torch.tensor(i), layer_past=layer_past, use_cache=use_cache, - output_attentions=output_attentions + output_attentions=output_attentions, ) hidden_states = layer_ret[0] @@ -1049,13 +1062,7 @@ def __init__(self, config: ChatGLMConfig, empty_init=True): self.transformer = ChatGLMModel(config, empty_init=empty_init) - self.lm_head = init_method( - nn.Linear, - config.hidden_size, - config.vocab_size, - bias=False, - dtype=torch.half - ) + self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half) self.config = config @@ -1087,32 +1094,29 @@ def _update_model_kwargs_for_generation( attention_mask = model_kwargs["attention_mask"] if attention_mask is not None and attention_mask.dtype == torch.bool: attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) + [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3 + ) new_attention_mask = attention_mask[:, :, -1:].clone() new_attention_mask[..., -1] = False - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, new_attention_mask], dim=2 - ) + model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2) # update position ids if "position_ids" in model_kwargs: position_ids = model_kwargs["position_ids"] new_position_id = position_ids[..., -1:].clone() new_position_id[:, 1, :] += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) + model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) return model_kwargs def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past: Optional[torch.Tensor] = None, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - **kwargs + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs, ) -> dict: batch_size, seq_length = input_ids.shape MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id @@ -1137,11 +1141,17 @@ def prepare_inputs_for_generation( context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] if self.position_encoding_2d: position_ids = torch.tensor( - [[mask_position, seq_length - context_length] for mask_position, context_length in - zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) + [ + [mask_position, seq_length - context_length] + for mask_position, context_length in zip(mask_positions, context_lengths) + ], + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(-1) else: - position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, - device=input_ids.device).unsqueeze(-1) + position_ids = torch.tensor( + [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device + ).unsqueeze(-1) if past is None: past = past_key_values @@ -1149,44 +1159,38 @@ def prepare_inputs_for_generation( "input_ids": last_token, "past_key_values": past, "position_ids": position_ids, - "attention_mask": attention_mask + "attention_mask": attention_mask, } else: if attention_mask is not None and attention_mask.dtype != torch.bool: logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") attention_mask = None if attention_mask is None: - attention_mask = self.get_masks( - input_ids, - device=input_ids.device - ) + attention_mask = self.get_masks(input_ids, device=input_ids.device) if position_ids is None: position_ids = self.get_position_ids( - input_ids, - device=input_ids.device, - mask_positions=mask_positions, - use_gmasks=use_gmasks + input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks ) return { "input_ids": input_ids, "past_key_values": past, "position_ids": position_ids, - "attention_mask": attention_mask + "attention_mask": attention_mask, } def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1235,7 +1239,7 @@ def forward( @staticmethod def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or @@ -1268,15 +1272,33 @@ def process_response(self, response): return response @torch.no_grad() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + num_beams=1, + do_sample=True, + top_p=0.7, + temperature=0.95, + logits_processor=None, + **kwargs, + ): if history is None: history = [] if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} + gen_kwargs = { + "max_length": max_length, + "num_beams": num_beams, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } if not history: prompt = query else: @@ -1287,22 +1309,38 @@ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max inputs = tokenizer([prompt], return_tensors="pt") inputs = inputs.to(self.device) outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] response = tokenizer.decode(outputs) response = self.process_response(response) history = history + [(query, response)] return response, history @torch.no_grad() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = None, + max_length: int = 2048, + do_sample=True, + top_p=0.7, + temperature=0.95, + logits_processor=None, + **kwargs, + ): if history is None: history = [] if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} + gen_kwargs = { + "max_length": max_length, + "do_sample": do_sample, + "top_p": top_p, + "temperature": temperature, + "logits_processor": logits_processor, + **kwargs, + } if not history: prompt = query else: @@ -1313,7 +1351,7 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No inputs = tokenizer([prompt], return_tensors="pt") inputs = inputs.to(self.device) for outputs in self.stream_generate(**inputs, **gen_kwargs): - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] response = tokenizer.decode(outputs) response = self.process_response(response) new_history = history + [(query, response)] @@ -1321,13 +1359,13 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No @torch.no_grad() def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - **kwargs, + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + **kwargs, ): batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py index de0d63f95f50..e3afac88c7a7 100644 --- a/applications/Chat/coati/models/generation.py +++ b/applications/Chat/coati/models/generation.py @@ -16,9 +16,9 @@ from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper -def _prepare_logits_processor(top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None) -> LogitsProcessorList: +def _prepare_logits_processor( + top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None +) -> LogitsProcessorList: processor_list = LogitsProcessorList() if temperature is not None and temperature != 1.0: processor_list.append(TemperatureLogitsWarper(temperature)) @@ -37,18 +37,20 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: return unfinished_sequences.max() == 0 -def _sample(model: Actor, - input_ids: torch.Tensor, - max_length: int, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: +def _sample( + model: Actor, + input_ids: torch.Tensor, + max_length: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +) -> torch.Tensor: if input_ids.size(1) >= max_length: return input_ids @@ -56,11 +58,12 @@ def _sample(model: Actor, unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) for _ in range(input_ids.size(1), max_length): - model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \ - if prepare_inputs_fn is not None else {'input_ids': input_ids} + model_inputs = ( + prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids} + ) outputs = model(**model_inputs) - next_token_logits = outputs['logits'][:, -1, :] + next_token_logits = outputs["logits"][:, -1, :] # pre-process distribution next_token_logits = logits_processor(input_ids, next_token_logits) # sample @@ -90,20 +93,22 @@ def _sample(model: Actor, @torch.no_grad() -def generate(model: Actor, - input_ids: torch.Tensor, - max_length: int, - num_beams: int = 1, - do_sample: bool = True, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: +def generate( + model: Actor, + input_ids: torch.Tensor, + max_length: int, + num_beams: int = 1, + do_sample: bool = True, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +) -> torch.Tensor: """Generate token sequence. The returned sequence is input_ids + generated_tokens. Args: @@ -121,26 +126,28 @@ def generate(model: Actor, prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None. update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None. """ - is_greedy_gen_mode = ((num_beams == 1) and do_sample is False) - is_sample_gen_mode = ((num_beams == 1) and do_sample is True) - is_beam_gen_mode = ((num_beams > 1) and do_sample is False) + is_greedy_gen_mode = (num_beams == 1) and do_sample is False + is_sample_gen_mode = (num_beams == 1) and do_sample is True + is_beam_gen_mode = (num_beams > 1) and do_sample is False if is_greedy_gen_mode: # run greedy search raise NotImplementedError elif is_sample_gen_mode: # run sample - return _sample(model, - input_ids, - max_length, - early_stopping=early_stopping, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - top_k=top_k, - top_p=top_p, - temperature=temperature, - prepare_inputs_fn=prepare_inputs_fn, - update_model_kwargs_fn=update_model_kwargs_fn, - **model_kwargs) + return _sample( + model, + input_ids, + max_length, + early_stopping=early_stopping, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + top_k=top_k, + top_p=top_p, + temperature=temperature, + prepare_inputs_fn=prepare_inputs_fn, + update_model_kwargs_fn=update_model_kwargs_fn, + **model_kwargs, + ) elif is_beam_gen_mode: raise NotImplementedError else: diff --git a/applications/Chat/coati/models/gpt/__init__.py b/applications/Chat/coati/models/gpt/__init__.py index 63dc5ab0f5ea..823cf4a75e0d 100644 --- a/applications/Chat/coati/models/gpt/__init__.py +++ b/applications/Chat/coati/models/gpt/__init__.py @@ -2,4 +2,4 @@ from .gpt_critic import GPTCritic from .gpt_rm import GPTRM -__all__ = ['GPTActor', 'GPTCritic', 'GPTRM'] +__all__ = ["GPTActor", "GPTCritic", "GPTRM"] diff --git a/applications/Chat/coati/models/gpt/gpt_actor.py b/applications/Chat/coati/models/gpt/gpt_actor.py index ae9d669f1f56..a7e4b9bc3e22 100644 --- a/applications/Chat/coati/models/gpt/gpt_actor.py +++ b/applications/Chat/coati/models/gpt/gpt_actor.py @@ -18,13 +18,15 @@ class GPTActor(Actor): lora_train_bias (str): Bias training strategy for the LoRa layer. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = GPT2LMHeadModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/gpt/gpt_critic.py b/applications/Chat/coati/models/gpt/gpt_critic.py index 01e1cd10ef57..22ab36dea276 100644 --- a/applications/Chat/coati/models/gpt/gpt_critic.py +++ b/applications/Chat/coati/models/gpt/gpt_critic.py @@ -18,12 +18,14 @@ class GPTCritic(Critic): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = GPT2Model.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/gpt/gpt_rm.py b/applications/Chat/coati/models/gpt/gpt_rm.py index e52a5a14c1da..8edfc4008466 100644 --- a/applications/Chat/coati/models/gpt/gpt_rm.py +++ b/applications/Chat/coati/models/gpt/gpt_rm.py @@ -18,11 +18,13 @@ class GPTRM(RewardModel): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[GPT2Config] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[GPT2Config] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = GPT2Model.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/llama/__init__.py b/applications/Chat/coati/models/llama/__init__.py index 9b2a024afdb2..c87d732538a9 100644 --- a/applications/Chat/coati/models/llama/__init__.py +++ b/applications/Chat/coati/models/llama/__init__.py @@ -2,4 +2,4 @@ from .llama_critic import LlamaCritic from .llama_rm import LlamaRM -__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM'] +__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"] diff --git a/applications/Chat/coati/models/llama/llama_actor.py b/applications/Chat/coati/models/llama/llama_actor.py index 2c7adb390d8b..f1d9406835ca 100644 --- a/applications/Chat/coati/models/llama/llama_actor.py +++ b/applications/Chat/coati/models/llama/llama_actor.py @@ -1,7 +1,6 @@ from typing import Optional -import torch -from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM +from transformers import LlamaConfig, LlamaForCausalLM from ..base import Actor @@ -18,13 +17,14 @@ class LlamaActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = LlamaForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py index a67e5de5def6..000dce17ccf0 100644 --- a/applications/Chat/coati/models/llama/llama_critic.py +++ b/applications/Chat/coati/models/llama/llama_critic.py @@ -17,13 +17,14 @@ class LlamaCritic(Critic): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: - + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = LlamaModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/llama/llama_rm.py b/applications/Chat/coati/models/llama/llama_rm.py index d6b62922686e..43bc9e638dc7 100644 --- a/applications/Chat/coati/models/llama/llama_rm.py +++ b/applications/Chat/coati/models/llama/llama_rm.py @@ -1,7 +1,7 @@ from typing import Optional import torch.nn as nn -from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel +from transformers import LlamaConfig, LlamaModel from ..base import RewardModel @@ -17,12 +17,13 @@ class LlamaRM(RewardModel): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[LlamaConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: - + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[LlamaConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = LlamaModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/lora.py b/applications/Chat/coati/models/lora.py index f1597da540a7..2114913e107b 100644 --- a/applications/Chat/coati/models/lora.py +++ b/applications/Chat/coati/models/lora.py @@ -8,8 +8,7 @@ class LoraLinear(lora.LoRALayer, nn.Module): - """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear. - """ + """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.""" def __init__( self, @@ -17,16 +16,14 @@ def __init__( bias: Optional[nn.Parameter], r: int = 0, lora_alpha: int = 1, - lora_dropout: float = 0., - fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) merge_weights: bool = True, ): nn.Module.__init__(self) - lora.LoRALayer.__init__(self, - r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - merge_weights=merge_weights) + lora.LoRALayer.__init__( + self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights + ) self.weight = weight self.bias = bias @@ -47,13 +44,12 @@ def __init__( self.weight.data = self.weight.data.T def reset_parameters(self): - if hasattr(self, 'lora_A'): + if hasattr(self, "lora_A"): # Initialize A with the default values for nn.Linear and set B to zero. nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_B) def train(self, mode: bool = True): - def T(w): return w.T if self.fan_in_fan_out else w @@ -71,7 +67,6 @@ def T(w): self.merged = False def eval(self): - def T(w): return w.T if self.fan_in_fan_out else w @@ -80,12 +75,11 @@ def T(w): # Merge the weights and mark it if self.r > 0: self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling - delattr(self, 'lora_A') - delattr(self, 'lora_B') + delattr(self, "lora_A") + delattr(self, "lora_B") self.merged = True def forward(self, x: torch.Tensor): - def T(w): return w.T if self.fan_in_fan_out else w @@ -99,7 +93,9 @@ def T(w): def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear: - assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})' + assert ( + lora_rank <= linear.in_features + ), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})" lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False) return lora_linear @@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None: _convert_to_lora_recursively(child, lora_rank) -def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module: +def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module: """Convert a torch.nn.Module to a LoRA module. Args: @@ -140,7 +136,7 @@ class LoRAModule(nn.Module): Defaults to 'none'. """ - def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: + def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None: super().__init__() self.lora_rank = lora_rank self.lora_train_bias = lora_train_bias diff --git a/applications/Chat/coati/models/loss.py b/applications/Chat/coati/models/loss.py index 05a0b4821797..4ad4f4dcd275 100644 --- a/applications/Chat/coati/models/loss.py +++ b/applications/Chat/coati/models/loss.py @@ -31,11 +31,13 @@ def __init__(self, clip_eps: float = 0.2) -> None: super().__init__() self.clip_eps = clip_eps - def forward(self, - log_probs: torch.Tensor, - old_log_probs: torch.Tensor, - advantages: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + log_probs: torch.Tensor, + old_log_probs: torch.Tensor, + advantages: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: ratio = (log_probs - old_log_probs).exp() surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages @@ -55,14 +57,16 @@ def __init__(self, clip_eps: float = 0.4) -> None: super().__init__() self.clip_eps = clip_eps - def forward(self, - values: torch.Tensor, - old_values: torch.Tensor, - reward: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + values: torch.Tensor, + old_values: torch.Tensor, + reward: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) - surr1 = (values_clipped - reward)**2 - surr2 = (values - reward)**2 + surr1 = (values_clipped - reward) ** 2 + surr2 = (values - reward) ** 2 loss = torch.max(surr1, surr2) loss = loss.mean() return 0.5 * loss diff --git a/applications/Chat/coati/models/opt/__init__.py b/applications/Chat/coati/models/opt/__init__.py index 334f4df0032a..e37d6e45c8fc 100644 --- a/applications/Chat/coati/models/opt/__init__.py +++ b/applications/Chat/coati/models/opt/__init__.py @@ -2,4 +2,4 @@ from .opt_critic import OPTCritic from .opt_rm import OPTRM -__all__ = ['OPTActor', 'OPTCritic', 'OPTRM'] +__all__ = ["OPTActor", "OPTCritic", "OPTRM"] diff --git a/applications/Chat/coati/models/opt/opt_actor.py b/applications/Chat/coati/models/opt/opt_actor.py index c14e4377ffb2..cd8908e13fb8 100644 --- a/applications/Chat/coati/models/opt/opt_actor.py +++ b/applications/Chat/coati/models/opt/opt_actor.py @@ -18,12 +18,14 @@ class OPTActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - checkpoint: bool = False, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = OPTForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/opt/opt_critic.py b/applications/Chat/coati/models/opt/opt_critic.py index f66c4173fa52..f37d28812c27 100644 --- a/applications/Chat/coati/models/opt/opt_critic.py +++ b/applications/Chat/coati/models/opt/opt_critic.py @@ -18,12 +18,14 @@ class OPTCritic(Critic): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none', - **kwargs) -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + **kwargs, + ) -> None: if pretrained is not None: model = OPTModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/opt/opt_rm.py b/applications/Chat/coati/models/opt/opt_rm.py index 6f75344e6aae..893708344ad4 100644 --- a/applications/Chat/coati/models/opt/opt_rm.py +++ b/applications/Chat/coati/models/opt/opt_rm.py @@ -17,11 +17,13 @@ class OPTRM(RewardModel): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: Optional[str] = None, - config: Optional[OPTConfig] = None, - lora_rank: int = 0, - lora_train_bias: str = 'none') -> None: + def __init__( + self, + pretrained: Optional[str] = None, + config: Optional[OPTConfig] = None, + lora_rank: int = 0, + lora_train_bias: str = "none", + ) -> None: if pretrained is not None: model = OPTModel.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py index 97637d3523b0..def6190dd71c 100644 --- a/applications/Chat/coati/models/utils.py +++ b/applications/Chat/coati/models/utils.py @@ -4,9 +4,9 @@ import torch.nn.functional as F -def _compute_approx_kl(log_probs: torch.Tensor, - log_probs_base: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: +def _compute_approx_kl( + log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None +) -> torch.Tensor: """ Compute the approximate KL divergence between two distributions. Schulman blog: http://joschu.net/blog/kl-approx.html @@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor, return approx_kl -def compute_reward(r: Union[torch.Tensor, float], - kl_coef: float, - log_probs: torch.Tensor, - log_probs_base: torch.Tensor, - action_mask: Optional[torch.Tensor] = None) -> torch.Tensor: +def compute_reward( + r: Union[torch.Tensor, float], + kl_coef: float, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: if kl_coef <= 0.0: return r kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask) @@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num Returns: torch.Tensor: Action log probs. """ - logits = output['logits'] + logits = output["logits"] log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return log_probs[:, -num_actions:] diff --git a/applications/Chat/coati/quant/__init__.py b/applications/Chat/coati/quant/__init__.py index a65a78d07bb8..1765b8091bc3 100644 --- a/applications/Chat/coati/quant/__init__.py +++ b/applications/Chat/coati/quant/__init__.py @@ -2,6 +2,6 @@ from .utils import low_resource_init __all__ = [ - 'llama_load_quant', - 'low_resource_init', + "llama_load_quant", + "low_resource_init", ] diff --git a/applications/Chat/coati/quant/llama_gptq/__init__.py b/applications/Chat/coati/quant/llama_gptq/__init__.py index 51c8d6316290..51d5233586ad 100644 --- a/applications/Chat/coati/quant/llama_gptq/__init__.py +++ b/applications/Chat/coati/quant/llama_gptq/__init__.py @@ -1,5 +1,5 @@ from .loader import load_quant __all__ = [ - 'load_quant', + "load_quant", ] diff --git a/applications/Chat/coati/quant/llama_gptq/loader.py b/applications/Chat/coati/quant/llama_gptq/loader.py index 5353dc8a2ea3..50486337a7ab 100644 --- a/applications/Chat/coati/quant/llama_gptq/loader.py +++ b/applications/Chat/coati/quant/llama_gptq/loader.py @@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int): # ignore lm head layers = find_layers(model) - for name in ['lm_head']: + for name in ["lm_head"]: if name in layers: del layers[name] make_quant(model, layers, wbits, groupsize) - if checkpoint.endswith('.safetensors'): + if checkpoint.endswith(".safetensors"): from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) else: model.load_state_dict(torch.load(checkpoint)) diff --git a/applications/Chat/coati/quant/llama_gptq/model_utils.py b/applications/Chat/coati/quant/llama_gptq/model_utils.py index 62db171abb52..18e4e4761500 100644 --- a/applications/Chat/coati/quant/llama_gptq/model_utils.py +++ b/applications/Chat/coati/quant/llama_gptq/model_utils.py @@ -1,13 +1,12 @@ # copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py -import torch import torch.nn as nn -def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''): +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): if type(module) in layers: return {name: module} res = {} for name1, child in module.named_children(): - res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1)) + res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) return res diff --git a/applications/Chat/coati/quant/llama_gptq/quant.py b/applications/Chat/coati/quant/llama_gptq/quant.py index f7d5b7ce4bd8..5a7e2e72dfc5 100644 --- a/applications/Chat/coati/quant/llama_gptq/quant.py +++ b/applications/Chat/coati/quant/llama_gptq/quant.py @@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq): class Quantizer(nn.Module): - def __init__(self, shape=1): super(Quantizer, self).__init__() - self.register_buffer('maxq', torch.tensor(0)) - self.register_buffer('scale', torch.zeros(shape)) - self.register_buffer('zero', torch.zeros(shape)) + self.register_buffer("maxq", torch.tensor(0)) + self.register_buffer("scale", torch.zeros(shape)) + self.register_buffer("zero", torch.zeros(shape)) - def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8): + def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8): self.maxq = torch.tensor(2**bits - 1) self.perchannel = perchannel self.sym = sym @@ -68,7 +67,7 @@ def find_params(self, x, weight=False): self.zero = torch.round(-xmin / self.scale) if self.mse: - best = torch.full([x.shape[0]], float('inf'), device=dev) + best = torch.full([x.shape[0]], float("inf"), device=dev) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid xmin1 = p * xmin @@ -123,13 +122,12 @@ def ready(self): try: import quant_cuda except: - print('CUDA extension not installed.') + print("CUDA extension not installed.") # Assumes layer is perfectly divisible into 256 * 256 blocks class QuantLinear(nn.Module): - def __init__(self, bits, groupsize, infeatures, outfeatures): super().__init__() if bits not in [2, 3, 4, 8]: @@ -142,11 +140,11 @@ def __init__(self, bits, groupsize, infeatures, outfeatures): groupsize = groupsize if groupsize != -1 else infeatures self.groupsize = groupsize self.register_buffer( - 'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), - dtype=torch.int)) - self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures))) - self.register_buffer('bias', torch.zeros(outfeatures)) - self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)) + "qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int) + ) + self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures))) + self.register_buffer("bias", torch.zeros(outfeatures)) + self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)) self._initialized_quant_state = False def pack(self, linear, scales, zeros): @@ -161,8 +159,10 @@ def pack(self, linear, scales, zeros): for idx in range(self.infeatures): g_idx = idx // self.groupsize intweight.append( - torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, - None]) + torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[ + :, None + ] + ) intweight = torch.cat(intweight, dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(np.uint32) @@ -271,13 +271,13 @@ def forward(self, x): return y.reshape(outshape) -def make_quant(module, names, bits, groupsize, name=''): +def make_quant(module, names, bits, groupsize, name=""): if isinstance(module, QuantLinear): return for attr in dir(module): tmp = getattr(module, attr) - name1 = name + '.' + attr if name != '' else attr + name1 = name + "." + attr if name != "" else attr if name1 in names: setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)) for name1, child in module.named_children(): - make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1) + make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1) diff --git a/applications/Chat/coati/quant/utils.py b/applications/Chat/coati/quant/utils.py index 01b8cff0add1..d102bb30f52d 100644 --- a/applications/Chat/coati/quant/utils.py +++ b/applications/Chat/coati/quant/utils.py @@ -9,8 +9,7 @@ def _noop(*args, **kwargs): @contextmanager def low_resource_init(): - """This context manager disables weight initialization and sets the default float dtype to half. - """ + """This context manager disables weight initialization and sets the default float dtype to half.""" old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_ old_uniform_ = torch.nn.init.uniform_ old_normal_ = torch.nn.init.normal_ diff --git a/applications/Chat/coati/ray/callbacks/base.py b/applications/Chat/coati/ray/callbacks/base.py index 3306150a41ff..8c5bd8a67776 100644 --- a/applications/Chat/coati/ray/callbacks/base.py +++ b/applications/Chat/coati/ray/callbacks/base.py @@ -5,7 +5,7 @@ class TrainerCallback(ABC): """ - Base callback class. It defines the interface for callbacks. + Base callback class. It defines the interface for callbacks. """ def on_fit_start(self) -> None: @@ -40,7 +40,6 @@ def on_update_end(self) -> None: class MakerCallback(ABC): - def on_loop_start(self) -> None: pass diff --git a/applications/Chat/coati/ray/callbacks/performance_evaluator.py b/applications/Chat/coati/ray/callbacks/performance_evaluator.py index d3df8f9ae3e0..18798bce7dce 100644 --- a/applications/Chat/coati/ray/callbacks/performance_evaluator.py +++ b/applications/Chat/coati/ray/callbacks/performance_evaluator.py @@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float: class Timer: - def __init__(self) -> None: self.start_time: Optional[float] = None - self.duration: float = 0. + self.duration: float = 0.0 def start(self) -> None: self.start_time = time() @@ -42,13 +41,13 @@ def end(self) -> None: self.duration += time() - self.start_time def reset(self) -> None: - self.duration = 0. + self.duration = 0.0 class ExperienceMakerPerformanceEvaluator(MakerCallback): - - def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, - reward_model_num_params: int) -> None: + def __init__( + self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int + ) -> None: super().__init__() self.world_size = get_world_size() self.actor_num_params = actor_num_params @@ -63,7 +62,7 @@ def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_ self.make_experience_flop: int = 0 print_rank_0( - f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}' + f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}" ) def on_make_experience_start(self) -> None: @@ -110,27 +109,29 @@ def on_loop_end(self) -> None: avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12) avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size) - avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \ - (self.total_samples * self.world_size) + avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / ( + self.total_samples * self.world_size + ) avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size) print_rank_0( - 'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' - + f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n' - + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' - + f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n' - - + f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n' + "Making Experience Performance Summary:\n" + + f"Throughput: {avg_throughput:.3f} samples/sec\n" + + f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n" + + f"Sample time (overall): {avg_time_per_sample:.3f} s\n" + + f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n" + + f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n" ) class TrainerPerformanceEvaluator(TrainerCallback): - - def __init__(self, - actor_num_params: int, - critic_num_params: int, - enable_grad_checkpoint: bool = False, - ignore_first_episodes: int = 1) -> None: + def __init__( + self, + actor_num_params: int, + critic_num_params: int, + enable_grad_checkpoint: bool = False, + ignore_first_episodes: int = 1, + ) -> None: super().__init__() self.world_size = get_world_size() self.actor_num_params = actor_num_params @@ -146,7 +147,7 @@ def __init__(self, self.learn_flop: int = 0 print_rank_0( - f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}' + f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}" ) def on_episode_start(self, episodes: int) -> None: @@ -191,7 +192,7 @@ def on_update_end(self) -> None: def on_fit_end(self) -> None: if self.total_samples == 0: - print_rank_0('No samples are collected, skip trainer performance evaluation') + print_rank_0("No samples are collected, skip trainer performance evaluation") return avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size) avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size) @@ -204,9 +205,10 @@ def on_fit_end(self) -> None: avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size) print_rank_0( - 'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n' - + f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n' - + f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n' - - + f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n' + "Learning Performance Summary:\n" + + f"Throughput: {avg_throughput:.3f} samples/sec\n" + + f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n" + + f"Sample time (overall): {avg_time_per_sample:.3f} s\n" + + f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n" + + f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n" ) diff --git a/applications/Chat/coati/ray/detached_replay_buffer.py b/applications/Chat/coati/ray/detached_replay_buffer.py index e04bf5ccb881..92dab17292f7 100644 --- a/applications/Chat/coati/ray/detached_replay_buffer.py +++ b/applications/Chat/coati/ray/detached_replay_buffer.py @@ -1,20 +1,15 @@ -import asyncio -import copy -import random -from threading import Lock -from typing import Any, List +from typing import List -import ray import torch -from coati.experience_buffer import ExperienceBuffer from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch from coati.experience_maker.base import Experience + # from torch.multiprocessing import Queue from ray.util.queue import Queue class DetachedReplayBuffer: - ''' + """ Detached replay buffer. Share Experience across workers on the same node. Therefore, a trainer node is expected to have only one instance. It is ExperienceMakerHolder's duty to call append(exp) method, remotely. @@ -24,7 +19,7 @@ class DetachedReplayBuffer: tp_world_size: Number of workers in the same tp group limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0. cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True. - ''' + """ def __init__(self, sample_batch_size: int, limit: int = 0) -> None: self.sample_batch_size = sample_batch_size @@ -34,23 +29,23 @@ def __init__(self, sample_batch_size: int, limit: int = 0) -> None: @torch.no_grad() def append(self, experience: Experience) -> None: - ''' + """ Expected to be called remotely. - ''' + """ items = split_experience_batch(experience) self.extend(items) @torch.no_grad() def extend(self, items: List[BufferItem]) -> None: - ''' + """ Expected to be called remotely. - ''' + """ self.batch_collector.extend(items) while len(self.batch_collector) >= self.sample_batch_size: - items = self.batch_collector[:self.sample_batch_size] + items = self.batch_collector[: self.sample_batch_size] experience = make_experience_batch(items) self.items.put(experience, block=True) - self.batch_collector = self.batch_collector[self.sample_batch_size:] + self.batch_collector = self.batch_collector[self.sample_batch_size :] def clear(self) -> None: # self.items.close() diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/Chat/coati/ray/detached_trainer_base.py index 90399781187a..fcf0a472df9e 100644 --- a/applications/Chat/coati/ray/detached_trainer_base.py +++ b/applications/Chat/coati/ray/detached_trainer_base.py @@ -1,6 +1,6 @@ import os from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, List import ray import torch @@ -15,7 +15,7 @@ class DetachedTrainer(ABC): - ''' + """ Base class for detached rlhf trainers. 'detach' means that the experience maker is detached compared to a normal Trainer. Please set name attribute during init: @@ -28,15 +28,17 @@ class DetachedTrainer(ABC): callbacks (List[Callback], defaults to []): the callbacks to call during training process generate_kwargs (dict, optional): the kwargs to use while model generating - ''' - - def __init__(self, - experience_maker_holder_name_list: List[str], - train_batch_size: int = 8, - buffer_limit: int = 0, - dataloader_pin_memory: bool = True, - callbacks: List[TrainerCallback] = [], - debug: bool = False) -> None: + """ + + def __init__( + self, + experience_maker_holder_name_list: List[str], + train_batch_size: int = 8, + buffer_limit: int = 0, + dataloader_pin_memory: bool = True, + callbacks: List[TrainerCallback] = [], + debug: bool = False, + ) -> None: super().__init__() self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit) self.dataloader_pin_memory = dataloader_pin_memory @@ -67,18 +69,16 @@ def training_step(self, experience: Experience) -> Dict[str, Any]: def _learn(self, update_steps: int, train_epochs: int) -> None: data = [] # warmup - pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0()) + pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0()) self._on_epoch_start(0) self._learn_epoch(pbar, data) self._on_epoch_end(0) # item is already a batch - dataloader = DataLoader(data, - batch_size=1, - shuffle=True, - pin_memory=self.dataloader_pin_memory, - collate_fn=lambda x: x[0]) + dataloader = DataLoader( + data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0] + ) for epoch in range(1, train_epochs): - pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0()) + pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0()) self._on_epoch_start(epoch) self._learn_epoch(pbar, data) self._on_epoch_end(epoch) @@ -104,7 +104,7 @@ def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None: def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None: self._on_fit_start() - for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()): + for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()): self._on_episode_start(i) self._learn(update_steps, train_epochs) self._on_update_start() diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/Chat/coati/ray/detached_trainer_ppo.py index 2f2aa0e29579..ef84a1ddba48 100644 --- a/applications/Chat/coati/ray/detached_trainer_ppo.py +++ b/applications/Chat/coati/ray/detached_trainer_ppo.py @@ -1,12 +1,11 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Tuple import ray import torch -from coati.experience_maker import Experience, NaiveExperienceMaker +from coati.experience_maker import Experience from coati.models.base import Actor, Critic from coati.models.loss import PolicyLoss, ValueLoss -from coati.trainer.callbacks import Callback -from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy +from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy from torch.optim import Adam from colossalai.nn.optimizer import HybridAdam @@ -14,27 +13,14 @@ from .callbacks import TrainerCallback, TrainerPerformanceEvaluator from .detached_trainer_base import DetachedTrainer from .lora_constructor import LoRAConstructor -from .utils import ( - get_actor_from_args, - get_critic_from_args, - get_model_numel, - get_rank, - get_strategy_from_args, - is_rank_0, - set_dist_env, - state_dict_to, -) +from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to -@ray.remote(concurrency_groups={ - "buffer_length": 1, - "buffer_append": 1, - "buffer_sample": 1, - "model_io": 1, - "compute": 1 -}) +@ray.remote( + concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1} +) class DetachedPPOTrainer(DetachedTrainer): - ''' + """ Detached Trainer for PPO algorithm Args: strategy (Strategy): the strategy to use for training @@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer): dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader callbacks (List[Callback], defaults to []): the callbacks to call during training process generate_kwargs (dict, optional): the kwargs to use while model generating - ''' + """ def __init__( self, @@ -92,21 +78,24 @@ def __init__( self.actor_optim = Adam(self.actor.parameters(), lr=1e-7) self.critic_optim = Adam(self.critic.parameters(), lr=1e-7) - (self.actor, self.actor_optim), (self.critic, self.critic_optim) = \ - self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim)) + (self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare( + (self.actor, self.actor_optim), (self.critic, self.critic_optim) + ) # configure trainer self.actor_loss_fn = PolicyLoss(eps_clip) self.critic_loss_fn = ValueLoss(value_clip) - super().__init__(experience_maker_holder_name_list, - train_batch_size=train_batch_size, - buffer_limit=buffer_limit, - dataloader_pin_memory=dataloader_pin_memory, - callbacks=callbacks, - debug=debug) + super().__init__( + experience_maker_holder_name_list, + train_batch_size=train_batch_size, + buffer_limit=buffer_limit, + dataloader_pin_memory=dataloader_pin_memory, + callbacks=callbacks, + debug=debug, + ) if self._debug: - print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}') + print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}") self._update_lora_weights = update_lora_weights @@ -115,7 +104,7 @@ def __init__( def _update_remote_makers(self, fully_update: bool = False, **config): # TODO: balance duties if not fully_update: - config['requires_grad_only'] = True + config["requires_grad_only"] = True self.update_target_holder_list() # mark start, ensure order tasks = [] @@ -131,7 +120,9 @@ def _update_remote_makers(self, fully_update: bool = False, **config): target_holder.update_experience_maker.remote( new_actor_state_dict=state_dict_shard, new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor), - fully_update=fully_update)) + fully_update=fully_update, + ) + ) # sending loop for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config): for target_holder in self.target_holder_list: @@ -139,7 +130,9 @@ def _update_remote_makers(self, fully_update: bool = False, **config): target_holder.update_experience_maker.remote( new_critic_state_dict=state_dict_shard, new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic), - fully_update=fully_update)) + fully_update=fully_update, + ) + ) ray.get(tasks) # mark end for target_holder in self.target_holder_list: @@ -152,26 +145,24 @@ def training_step(self, experience: Experience) -> Dict[str, float]: num_actions = experience.action_mask.size(1) action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) - actor_loss = self.actor_loss_fn(action_log_probs, - experience.action_log_probs, - experience.advantages, - action_mask=experience.action_mask) + actor_loss = self.actor_loss_fn( + action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + ) self.strategy.backward(actor_loss, self.actor, self.actor_optim) self.strategy.optimizer_step(self.actor_optim) self.actor_optim.zero_grad() - values = self.critic(experience.sequences, - action_mask=experience.action_mask, - attention_mask=experience.attention_mask) - critic_loss = self.critic_loss_fn(values, - experience.values, - experience.reward, - action_mask=experience.action_mask) + values = self.critic( + experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask + ) + critic_loss = self.critic_loss_fn( + values, experience.values, experience.reward, action_mask=experience.action_mask + ) self.strategy.backward(critic_loss, self.critic, self.critic_optim) self.strategy.optimizer_step(self.critic_optim) self.critic_optim.zero_grad() - return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} + return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()} def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None: self.strategy.save_model(self.actor, path, only_rank0) diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py index 13314bdafd5f..4d290f4aba88 100644 --- a/applications/Chat/coati/ray/experience_maker_holder.py +++ b/applications/Chat/coati/ray/experience_maker_holder.py @@ -1,53 +1,49 @@ import os import time import tracemalloc -from copy import deepcopy from threading import Lock -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union import ray import torch -import torch.nn as nn -from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch -from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker +from coati.experience_buffer.utils import split_experience_batch +from coati.experience_maker import Experience, NaiveExperienceMaker from coati.models.base import Actor, Critic, RewardModel -from coati.trainer.callbacks import Callback from coati.trainer.strategies import Strategy -from coati.trainer.strategies.sampler import DistributedSampler -from ray.exceptions import GetTimeoutError from torch import Tensor from tqdm import tqdm from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback from .lora_constructor import LoRAConstructor -from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to +from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1}) class ExperienceMakerHolder: - ''' + """ Args: detached_trainer_name_list: str list to get ray actor handles strategy: kl_coef: the coefficient of kl divergence loss sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models. - ''' + """ def __init__( - self, - detached_trainer_name_list: List[str], - strategy_fn: Callable[[], Strategy], + self, + detached_trainer_name_list: List[str], + strategy_fn: Callable[[], Strategy], # a function returns (actor, critic, reward_model, initial_model) - model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]], - env_info: Dict[str, str] = None, - sync_models_from_trainers: bool = False, - buffer_cpu_offload: bool = True, - kl_coef: float = 0.1, - callbacks: List[MakerCallback] = [], - eval_performance: bool = False, - debug: bool = False, - update_lora_weights: bool = False, - **generate_kwargs): + model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]], + env_info: Dict[str, str] = None, + sync_models_from_trainers: bool = False, + buffer_cpu_offload: bool = True, + kl_coef: float = 0.1, + callbacks: List[MakerCallback] = [], + eval_performance: bool = False, + debug: bool = False, + update_lora_weights: bool = False, + **generate_kwargs, + ): # set environment variables if env_info: set_dist_env(env_info=env_info) @@ -66,8 +62,9 @@ def __init__( critic_numel = get_model_numel(critic) initial_model_numel = get_model_numel(initial_model) reward_model_numel = get_model_numel(reward_model) - evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel, - reward_model_numel) + evaluator = ExperienceMakerPerformanceEvaluator( + actor_numel, critic_numel, initial_model_numel, reward_model_numel + ) callbacks = callbacks + [evaluator] actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model) @@ -89,9 +86,9 @@ def __init__( self._target_idx = 0 if self._debug: - print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}') + print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}") if not self._is_fully_initialized: - print(f'[maker{get_rank()}] Waiting for INIT') + print(f"[maker{get_rank()}] Waiting for INIT") def _get_ready(self): while not self._fully_initialized(): @@ -136,7 +133,7 @@ def _inference_step(self, batch) -> None: self._on_make_experience_end(experience) self._on_send_start() if self.buffer_cpu_offload: - experience.to_device('cpu') + experience.to_device("cpu") self._send_items(experience) self._on_send_end() self._on_batch_end() @@ -155,7 +152,7 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1 if num_steps > 0: # ignore num epochs it = iter(dataloader) - for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()): + for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()): try: batch = next(it) except StopIteration: @@ -163,7 +160,7 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1 batch = next(it) self._inference_step(batch) else: - with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar: + with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar: for _ in range(num_epochs): for batch in dataloader: self._inference_step(batch) @@ -171,22 +168,24 @@ def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1 self._on_loop_end() @ray.method(concurrency_group="model_io") - def update_experience_maker(self, - new_actor_state_dict: Dict[str, Any] = None, - new_actor_lora_config_dict: Dict[str, Any] = None, - new_critic_state_dict: Dict[str, Any] = None, - new_critic_lora_config_dict: Dict[str, Any] = None, - fully_update: bool = False, - chunk_start: bool = None, - chunk_end: bool = None): - ''' - called by trainer - chunk_start: Set True at the first call. Before sending state_dict calls - chunk_end: Set True at the last call. After sending state_dict calls. - fully_update: Set True if you want to sync models when initializing - - TODO: load_state_dict integrate with model-sharding strategy - ''' + def update_experience_maker( + self, + new_actor_state_dict: Dict[str, Any] = None, + new_actor_lora_config_dict: Dict[str, Any] = None, + new_critic_state_dict: Dict[str, Any] = None, + new_critic_lora_config_dict: Dict[str, Any] = None, + fully_update: bool = False, + chunk_start: bool = None, + chunk_end: bool = None, + ): + """ + called by trainer + chunk_start: Set True at the first call. Before sending state_dict calls + chunk_end: Set True at the last call. After sending state_dict calls. + fully_update: Set True if you want to sync models when initializing + + TODO: load_state_dict integrate with model-sharding strategy + """ _watch_memory = self._debug if chunk_start: if self._debug: @@ -202,18 +201,22 @@ def update_experience_maker(self, else: new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device()) state_dict_increase = self.actor_lora_constructor.reconstruct_increase( - new_actor_state_dict, new_actor_lora_config_dict) + new_actor_state_dict, new_actor_lora_config_dict + ) self.actor_lora_constructor.load_state_dict_increase( - self.experience_maker.actor.model, state_dict_increase) + self.experience_maker.actor.model, state_dict_increase + ) if new_critic_state_dict is not None: if not self._update_lora_weights or fully_update: self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False) else: new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device()) state_dict_increase = self.critic_lora_constructor.reconstruct_increase( - new_critic_state_dict, new_critic_lora_config_dict) + new_critic_state_dict, new_critic_lora_config_dict + ) self.critic_lora_constructor.load_state_dict_increase( - self.experience_maker.critic, state_dict_increase) + self.experience_maker.critic, state_dict_increase + ) # the lock must be released after both actor and critic being updated if chunk_end: @@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None: origin_model = actor.model new_kwargs = {**generate_kwargs} # use huggingface models method directly - if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): - new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation + if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"): + new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation - if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'): - new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation + if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"): + new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation return new_kwargs diff --git a/applications/Chat/coati/ray/lora_constructor.py b/applications/Chat/coati/ray/lora_constructor.py index a98545d4d751..8e9f78700e29 100644 --- a/applications/Chat/coati/ray/lora_constructor.py +++ b/applications/Chat/coati/ray/lora_constructor.py @@ -1,11 +1,9 @@ from collections import OrderedDict from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict -import torch import torch.nn as nn from coati.models.lora import LoraLinear -from loralib.layers import LoRALayer @dataclass @@ -17,7 +15,7 @@ class LoRAConfig: class LoRAConstructor: - ''' + """ Tools for reconstructing a model from a remote LoRA model. (Transferring only LoRA data costs much less!) Usage: @@ -36,7 +34,7 @@ class LoRAConstructor: Step 5 (Receiver): load_state_dict_increase() - ''' + """ def __init__(self): self.lora_config_dict = None @@ -45,10 +43,10 @@ def register_lora_config(self, lora_config_dict: Dict[str, Any]): self.lora_config_dict = lora_config_dict def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]): - ''' - xxx.lora_A, xxx.lora_B -->> xxx.weight - Warning: the xxx.weight here is the increment actually. - ''' + """ + xxx.lora_A, xxx.lora_B -->> xxx.weight + Warning: the xxx.weight here is the increment actually. + """ if lora_config_dict is not None: self.register_lora_config(lora_config_dict) @@ -56,24 +54,25 @@ def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict config_iter = iter(self.lora_config_dict.items()) lora_A, lora_B, layer_prefix = None, None, None for k, v in state_dict_lora.items(): - if k.rpartition('.')[-1] == 'lora_A': + if k.rpartition(".")[-1] == "lora_A": lora_A = v - layer_prefix = k.rpartition('.')[0] - elif k.rpartition('.')[-1] == 'lora_B': - assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair" + layer_prefix = k.rpartition(".")[0] + elif k.rpartition(".")[-1] == "lora_B": + assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair" layer_prefix_2, config = next(config_iter) assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair" lora_B = v weight_data_increase = self._compute(lora_A, lora_B, config) - state_dict_increase[layer_prefix + '.weight'] = weight_data_increase + state_dict_increase[layer_prefix + ".weight"] = weight_data_increase lora_A, lora_B, layer_prefix = None, None, None else: - raise ValueError('unexpected key') + raise ValueError("unexpected key") return state_dict_increase def _compute(self, lora_A, lora_B, config=LoRAConfig()): def T(w): return w.T if config.fan_in_fan_out else w + if config.r > 0: scaling = config.lora_alpha / config.r weight_data_increase = T(lora_B @ lora_A) * scaling @@ -81,21 +80,21 @@ def T(w): return 0 def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]): - ''' + """ The final reconstruction step - ''' + """ # naive approach model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False) @staticmethod def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False): - ''' + """ if keep_non_lora, also return non_lora state_dict - ''' + """ state_dict_lora = OrderedDict() state_dict_non_lora = OrderedDict() for k, v in state_dict.items(): - if 'lora_A' in k or 'lora_B' in k: + if "lora_A" in k or "lora_B" in k: state_dict_lora[k] = v elif keep_non_lora: state_dict_non_lora[k] = v @@ -106,17 +105,19 @@ def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False): @staticmethod def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]: - ''' + """ extract LoraLinear model. return OrderedDict(): name -> LoRAConfig - ''' + """ lora_config_dict = OrderedDict() for name, child in model.named_modules(): if isinstance(child, LoraLinear): - lora_config_dict[name] = LoRAConfig(r=child.r, - lora_alpha=child.lora_alpha, - lora_dropout=child.lora_dropout, - fan_in_fan_out=child.fan_in_fan_out) + lora_config_dict[name] = LoRAConfig( + r=child.r, + lora_alpha=child.lora_alpha, + lora_dropout=child.lora_dropout, + fan_in_fan_out=child.fan_in_fan_out, + ) return lora_config_dict diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py index 391ffe7a91a9..036dd145dddb 100644 --- a/applications/Chat/coati/ray/utils.py +++ b/applications/Chat/coati/ray/utils.py @@ -1,6 +1,6 @@ import os from collections import OrderedDict -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict import torch import torch.distributed as dist @@ -10,7 +10,7 @@ from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy -from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer def is_rank_0() -> bool: @@ -26,13 +26,13 @@ def get_world_size() -> int: def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0): - if model == 'gpt2': + if model == "gpt2": actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank) - elif model == 'bloom': + elif model == "bloom": actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank) - elif model == 'opt': + elif model == "opt": actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank) - elif model == 'llama': + elif model == "llama": actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank) else: raise ValueError(f'Unsupported actor model "{model}"') @@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0): - if model == 'gpt2': + if model == "gpt2": critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) - elif model == 'bloom': + elif model == "bloom": critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) - elif model == 'opt': + elif model == "opt": critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) - elif model == 'llama': + elif model == "llama": critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True) else: raise ValueError(f'Unsupported reward model "{model}"') @@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r def get_reward_model_from_args(model: str, pretrained: str = None, config=None): - if model == 'gpt2': + if model == "gpt2": reward_model = GPTRM(pretrained=pretrained, config=config) - elif model == 'bloom': + elif model == "bloom": reward_model = BLOOMRM(pretrained=pretrained, config=config) - elif model == 'opt': + elif model == "opt": reward_model = OPTRM(pretrained=pretrained, config=config) - elif model == 'llama': + elif model == "llama": reward_model = LlamaRM(pretrained=pretrained, config=config) else: raise ValueError(f'Unsupported reward model "{model}"') @@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None): def get_strategy_from_args(strategy: str): - if strategy == 'ddp': + if strategy == "ddp": strategy_ = DDPStrategy() - elif strategy == 'colossalai_gemini': - strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) - elif strategy == 'colossalai_zero2': - strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda') - elif strategy == 'colossalai_gemini_cpu': - strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5) - elif strategy == 'colossalai_zero2_cpu': - strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu') + elif strategy == "colossalai_gemini": + strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + elif strategy == "colossalai_zero2": + strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda") + elif strategy == "colossalai_gemini_cpu": + strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5) + elif strategy == "colossalai_zero2_cpu": + strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{strategy}"') return strategy_ def get_tokenizer_from_args(model: str, **kwargs): - if model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - elif model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') - elif model == 'opt': + if model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + elif model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") + elif model == "opt": tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") - elif model == 'llama': + elif model == "llama": pretrain_path = kwargs["pretrain"] tokenizer = AutoTokenizer.from_pretrained(pretrain_path) else: @@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs): def set_dist_env(env_info: Dict[str, str]): - os.environ["RANK"] = env_info['rank'] - os.environ["LOCAL_RANK"] = env_info['local_rank'] - os.environ["WORLD_SIZE"] = env_info['world_size'] - os.environ['MASTER_PORT'] = env_info['master_port'] - os.environ['MASTER_ADDR'] = env_info['master_addr'] + os.environ["RANK"] = env_info["rank"] + os.environ["LOCAL_RANK"] = env_info["local_rank"] + os.environ["WORLD_SIZE"] = env_info["world_size"] + os.environ["MASTER_PORT"] = env_info["master_port"] + os.environ["MASTER_ADDR"] = env_info["master_addr"] def get_model_numel(model: nn.Module) -> int: @@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i return target_receivers -def state_dict_to(state_dict: Dict[str, Any], - dtype: torch.dtype = torch.float16, - device: torch.device = torch.device('cpu')): - ''' - keep state_dict intact - ''' +def state_dict_to( + state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu") +): + """ + keep state_dict intact + """ new_state_dict = OrderedDict() for k, v in state_dict.items(): new_state_dict[k] = v.to(dtype=dtype, device=device) diff --git a/applications/Chat/coati/trainer/__init__.py b/applications/Chat/coati/trainer/__init__.py index 86142361f3ff..4be5d27f93b1 100644 --- a/applications/Chat/coati/trainer/__init__.py +++ b/applications/Chat/coati/trainer/__init__.py @@ -3,8 +3,4 @@ from .rm import RewardModelTrainer from .sft import SFTTrainer -__all__ = [ - 'SLTrainer', 'OnPolicyTrainer', - 'RewardModelTrainer', 'SFTTrainer', - 'PPOTrainer' -] +__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"] diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py index 0629c9c00cca..ca450edee0c3 100644 --- a/applications/Chat/coati/trainer/base.py +++ b/applications/Chat/coati/trainer/base.py @@ -68,12 +68,14 @@ class OnPolicyTrainer(ABC): callbacks (List[Callback], defaults to []): the callbacks to call during training process """ - def __init__(self, - strategy: Strategy, - data_buffer: NaiveExperienceBuffer, - sample_buffer: bool, - dataloader_pin_memory: bool, - callbacks: List[Callback] = []) -> None: + def __init__( + self, + strategy: Strategy, + data_buffer: NaiveExperienceBuffer, + sample_buffer: bool, + dataloader_pin_memory: bool, + callbacks: List[Callback] = [], + ) -> None: super().__init__() self.strategy = strategy self.data_buffer = data_buffer diff --git a/applications/Chat/coati/trainer/callbacks/__init__.py b/applications/Chat/coati/trainer/callbacks/__init__.py index 9ed0ee6f7640..29c8c4f00a5c 100644 --- a/applications/Chat/coati/trainer/callbacks/__init__.py +++ b/applications/Chat/coati/trainer/callbacks/__init__.py @@ -2,4 +2,4 @@ from .performance_evaluator import PerformanceEvaluator from .save_checkpoint import SaveCheckpoint -__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint'] +__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"] diff --git a/applications/Chat/coati/trainer/callbacks/base.py b/applications/Chat/coati/trainer/callbacks/base.py index f5616048855b..d5181175b324 100644 --- a/applications/Chat/coati/trainer/callbacks/base.py +++ b/applications/Chat/coati/trainer/callbacks/base.py @@ -5,7 +5,7 @@ class Callback(ABC): """ - Base callback class. It defines the interface for callbacks. + Base callback class. It defines the interface for callbacks. """ def on_fit_start(self) -> None: diff --git a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py index 9b44dafa7eaa..c2eda92cc165 100644 --- a/applications/Chat/coati/trainer/callbacks/performance_evaluator.py +++ b/applications/Chat/coati/trainer/callbacks/performance_evaluator.py @@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None: def divide(x: float, y: float) -> float: if y == 0: - return float('inf') - elif y == float('inf'): - return float('nan') + return float("inf") + elif y == float("inf"): + return float("nan") return x / y @@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float: class Timer: - def __init__(self) -> None: self.start_time: Optional[float] = None - self.duration: float = 0. + self.duration: float = 0.0 def start(self) -> None: self.start_time = time() @@ -52,7 +51,7 @@ def end(self) -> None: self.start_time = None def reset(self) -> None: - self.duration = 0. + self.duration = 0.0 class PerformanceEvaluator(Callback): @@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback): ignore_episodes: The number of episodes to ignore when calculating the performance. """ - def __init__(self, - actor_num_params: int, - critic_num_params: int, - initial_model_num_params: int, - reward_model_num_params: int, - enable_grad_checkpoint: bool = False, - ignore_episodes: int = 0) -> None: + def __init__( + self, + actor_num_params: int, + critic_num_params: int, + initial_model_num_params: int, + reward_model_num_params: int, + enable_grad_checkpoint: bool = False, + ignore_episodes: int = 0, + ) -> None: super().__init__() self.world_size = get_world_size() self.actor_num_params = actor_num_params @@ -155,8 +156,9 @@ def on_fit_end(self) -> None: avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size) avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size) - avg_make_experience_throughput = self.make_experience_num_samples * \ - self.world_size / (avg_make_experience_duration + 1e-12) + avg_make_experience_throughput = ( + self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12) + ) avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12) avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12) @@ -171,13 +173,11 @@ def on_fit_end(self) -> None: learn_time_per_sample = divide(avg_learn_duration, num_effective_samples) print_rank_0( - f'Performance summary:\n' - + f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n' - - + f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n' - + f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' - + f'Overall time per sample: {overall_time_per_sample:.2f} s\n' - + f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n' - - + f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%' + f"Performance summary:\n" + + f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n" + + f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n" + + f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n" + + f"Overall time per sample: {overall_time_per_sample:.2f} s\n" + + f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n" + + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%" ) diff --git a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py index f0d77a191a88..0d70b6c53073 100644 --- a/applications/Chat/coati/trainer/callbacks/save_checkpoint.py +++ b/applications/Chat/coati/trainer/callbacks/save_checkpoint.py @@ -36,34 +36,35 @@ class SaveCheckpoint(Callback): """ - def __init__(self, - path: str, - interval: int, - strategy: Strategy, - actor: nn.Module = None, - critic: nn.Module = None, - actor_optim: Optimizer = None, - critic_optim: Optimizer = None) -> None: + def __init__( + self, + path: str, + interval: int, + strategy: Strategy, + actor: nn.Module = None, + critic: nn.Module = None, + actor_optim: Optimizer = None, + critic_optim: Optimizer = None, + ) -> None: super().__init__() - self.path = os.path.join(path, 'checkpoint') + self.path = os.path.join(path, "checkpoint") self.interval = interval self.strategy = strategy - self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]} + self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]} def on_episode_end(self, episode: int) -> None: if (episode + 1) % self.interval != 0: return - base_path = os.path.join(self.path, f'episode_{episode}') + base_path = os.path.join(self.path, f"episode_{episode}") if not os.path.exists(base_path): os.makedirs(base_path) for model in self.model_dict.keys(): - # save model if self.model_dict[model][0] is None: # saving only optimizer states is meaningless, so it would be skipped continue - model_path = os.path.join(base_path, f'{model}.pt') + model_path = os.path.join(base_path, f"{model}.pt") self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True) # save optimizer @@ -71,5 +72,5 @@ def on_episode_end(self, episode: int) -> None: continue only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy)) rank = 0 if is_rank_0() else dist.get_rank() - optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt') + optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt") self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index ef625a1c1b3d..6f255a935d91 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -8,7 +8,7 @@ from coati.models.utils import calc_action_log_probs from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data import DistributedSampler from tqdm import tqdm from colossalai.utils import get_current_device @@ -24,11 +24,11 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto hf_model = get_base_model(unwrapper_model) new_kwargs = {**generate_kwargs} # use huggingface models method directly - if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'): - new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation + if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"): + new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation - if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'): - new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation + if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"): + new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation return new_kwargs @@ -60,38 +60,34 @@ class PPOTrainer(OnPolicyTrainer): generate_kwargs (dict, optional): the kwargs to use while model generating """ - def __init__(self, - strategy: Strategy, - actor: Actor, - critic: Critic, - reward_model: nn.Module, - initial_model: Actor, - actor_optim: Optimizer, - critic_optim: Optimizer, - kl_coef: float = 0.1, - ptx_coef: float = 0.9, - train_batch_size: int = 8, - buffer_limit: int = 0, - buffer_cpu_offload: bool = True, - eps_clip: float = 0.2, - vf_coef: float = 1.0, - value_clip: float = 0.4, - sample_buffer: bool = False, - dataloader_pin_memory: bool = True, - offload_inference_models: bool = True, - callbacks: List[Callback] = [], - **generate_kwargs - ) -> None: + def __init__( + self, + strategy: Strategy, + actor: Actor, + critic: Critic, + reward_model: nn.Module, + initial_model: Actor, + actor_optim: Optimizer, + critic_optim: Optimizer, + kl_coef: float = 0.1, + ptx_coef: float = 0.9, + train_batch_size: int = 8, + buffer_limit: int = 0, + buffer_cpu_offload: bool = True, + eps_clip: float = 0.2, + vf_coef: float = 1.0, + value_clip: float = 0.4, + sample_buffer: bool = False, + dataloader_pin_memory: bool = True, + offload_inference_models: bool = True, + callbacks: List[Callback] = [], + **generate_kwargs, + ) -> None: if isinstance(strategy, GeminiStrategy): - assert not offload_inference_models, \ - "GeminiPlugin is not compatible with manual model.to('cpu')" + assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')" data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload) - super().__init__( - strategy, data_buffer, - sample_buffer, dataloader_pin_memory, - callbacks - ) + super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks) self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor) self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef) @@ -130,18 +126,16 @@ def _training_step(self, experience: Experience) -> Dict[str, float]: num_actions = experience.action_mask.size(1) actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask) action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions) - actor_loss = self.actor_loss_fn(action_log_probs, - experience.action_log_probs, - experience.advantages, - action_mask=experience.action_mask) + actor_loss = self.actor_loss_fn( + action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + ) # ptx loss if self.ptx_coef != 0: batch = self.pretrain_dataloader.next() batch = to_device(batch, self.device) - ptx_log_probs = self.actor(batch['input_ids'], - attention_mask=batch['attention_mask'])['logits'] - ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels']) + ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"] + ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"]) actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) self.strategy.backward(actor_loss, self.actor, self.actor_optim) @@ -149,24 +143,23 @@ def _training_step(self, experience: Experience) -> Dict[str, float]: self.actor_optim.zero_grad() # value loss - values = self.critic(experience.sequences, - action_mask=experience.action_mask, - attention_mask=experience.attention_mask) - critic_loss = self.critic_loss_fn(values, - experience.values, - experience.reward, - action_mask=experience.action_mask) + values = self.critic( + experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask + ) + critic_loss = self.critic_loss_fn( + values, experience.values, experience.reward, action_mask=experience.action_mask + ) critic_loss = critic_loss * self.vf_coef self.strategy.backward(critic_loss, self.critic, self.critic_optim) self.strategy.optimizer_step(self.critic_optim) self.critic_optim.zero_grad() - return {'reward': experience.reward.mean().item()} + return {"reward": experience.reward.mean().item()} def _learn(self, update_step: int): if self.offload_inference_models: - self.experience_maker.initial_model.to('cpu') - self.experience_maker.reward_model.to('cpu') + self.experience_maker.initial_model.to("cpu") + self.experience_maker.reward_model.to("cpu") # buffer may be empty at first, we should rebuild at each training if self.sample_buffer: @@ -178,11 +171,7 @@ def _learn(self, update_step: int): else: if isinstance(self.dataloader.sampler, DistributedSampler): self.dataloader.sampler.set_epoch(update_step) - pbar = tqdm( - self.dataloader, - desc=f'Train epoch [{update_step + 1}]', - disable=not is_rank_0() - ) + pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0()) for experience in pbar: self._on_learn_batch_start() experience.to_device(self.device) diff --git a/applications/Chat/coati/trainer/rm.py b/applications/Chat/coati/trainer/rm.py index 54a5d0f40dea..a5d6974b3238 100644 --- a/applications/Chat/coati/trainer/rm.py +++ b/applications/Chat/coati/trainer/rm.py @@ -62,18 +62,15 @@ def _eval(self, epoch): if is_rank_0(): log = pd.DataFrame( - [[(epoch + 1) * len(self.train_dataloader), - self.loss.item(), self.dist, self.acc]], - columns=['step', 'loss', 'dist', 'acc'] + [[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]], + columns=["step", "loss", "dist", "acc"], ) - log.to_csv('log.csv', mode='a', header=False, index=False) + log.to_csv("log.csv", mode="a", header=False, index=False) def _train(self, epoch): self.model.train() step_bar = tqdm.trange( - len(self.train_dataloader), - desc='Train step of epoch %d' % epoch, - disable=not is_rank_0() + len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0() ) cnt = 0 for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader: @@ -93,10 +90,7 @@ def _train(self, epoch): step_bar.update() step_bar.close() - def _before_fit(self, - train_dataloader: DataLoader, - valid_dataloader: DataLoader, - eval_dataloader: DataLoader): + def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader): """ Args: train_dataloader (DataLoader): the dataloader to use for training @@ -104,7 +98,7 @@ def _before_fit(self, eval_dataloader (DataLoader): the dataloader to use for evaluation """ super()._before_fit() - self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") self.train_dataloader = train_dataloader self.valid_dataloader = valid_dataloader diff --git a/applications/Chat/coati/trainer/sft.py b/applications/Chat/coati/trainer/sft.py index e4d0a970740d..8deefc2c484e 100644 --- a/applications/Chat/coati/trainer/sft.py +++ b/applications/Chat/coati/trainer/sft.py @@ -39,8 +39,9 @@ def __init__( accumulation_steps: int = 8, ) -> None: if accumulation_steps > 1: - assert not isinstance(strategy, GeminiStrategy), \ - "Accumulation steps are not supported in stage 3 of ColossalAI" + assert not isinstance( + strategy, GeminiStrategy + ), "Accumulation steps are not supported in stage 3 of ColossalAI" super().__init__(strategy, max_epochs, model, optim) @@ -50,15 +51,11 @@ def __init__( def _train(self, epoch: int): self.model.train() for batch_id, batch in enumerate(self.train_dataloader): - batch = to_device(batch, torch.cuda.current_device()) if "attention_mask" in batch: - outputs = self.model(batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"]) + outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) else: - outputs = self.model(batch["input_ids"], - labels=batch["labels"]) + outputs = self.model(batch["input_ids"], labels=batch["labels"]) loss = outputs.loss loss = loss / self.accumulation_steps @@ -73,12 +70,14 @@ def _train(self, epoch: int): self.optimizer.zero_grad() self.scheduler.step() if is_rank_0() and self.use_wandb: - wandb.log({ - "loss": self.total_loss / self.accumulation_steps, - "lr": self.scheduler.get_last_lr()[0], - "epoch": epoch, - "batch_id": batch_id - }) + wandb.log( + { + "loss": self.total_loss / self.accumulation_steps, + "lr": self.scheduler.get_last_lr()[0], + "epoch": epoch, + "batch_id": batch_id, + } + ) self.total_loss = 0 self.step_bar.update() @@ -89,9 +88,9 @@ def _eval(self, epoch: int): loss_sum, num_seen = 0, 0 for batch in self.eval_dataloader: batch = to_device(batch, torch.cuda.current_device()) - outputs = self.model(batch["input_ids"], - attention_mask=batch["attention_mask"], - labels=batch["labels"]) + outputs = self.model( + batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"] + ) loss = outputs.loss loss_sum += loss.item() @@ -99,13 +98,15 @@ def _eval(self, epoch: int): loss_mean = loss_sum / num_seen if dist.get_rank() == 0: - self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}') + self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}") - def _before_fit(self, - train_dataloader: DataLoader, - eval_dataloader: Optional[DataLoader] = None, - logger: Optional[DistributedLogger] = None, - use_wandb: bool = False): + def _before_fit( + self, + train_dataloader: DataLoader, + eval_dataloader: Optional[DataLoader] = None, + logger: Optional[DistributedLogger] = None, + use_wandb: bool = False, + ): """ Args: train_dataloader: the dataloader to use for training @@ -124,6 +125,6 @@ def _before_fit(self, self.no_epoch_bar = True self.step_bar = tqdm.trange( len(self.train_dataloader) // self.accumulation_steps * self.max_epochs, - desc=f'steps', - disable=not is_rank_0() + desc=f"steps", + disable=not is_rank_0(), ) diff --git a/applications/Chat/coati/trainer/strategies/__init__.py b/applications/Chat/coati/trainer/strategies/__init__.py index b49a2c742db3..521dcb5855b1 100644 --- a/applications/Chat/coati/trainer/strategies/__init__.py +++ b/applications/Chat/coati/trainer/strategies/__init__.py @@ -2,7 +2,4 @@ from .colossalai import GeminiStrategy, LowLevelZeroStrategy from .ddp import DDPStrategy -__all__ = [ - 'Strategy', 'DDPStrategy', - 'LowLevelZeroStrategy', 'GeminiStrategy' -] +__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"] diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py index c20b2b16e396..303d4bc220a6 100644 --- a/applications/Chat/coati/trainer/strategies/base.py +++ b/applications/Chat/coati/trainer/strategies/base.py @@ -19,7 +19,7 @@ class Strategy(ABC): """ - Base class for training strategies. + Base class for training strategies. """ def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None: @@ -83,16 +83,18 @@ def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _Boo rets.append((model, optimizer)) elif isinstance(arg, Dict): model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg) - boost_result = dict(model=model, - optimizer=optimizer, - criterion=criterion, - dataloader=dataloader, - lr_scheduler=lr_scheduler) + boost_result = dict( + model=model, + optimizer=optimizer, + criterion=criterion, + dataloader=dataloader, + lr_scheduler=lr_scheduler, + ) # remove None values boost_result = {key: value for key, value in boost_result.items() if value is not None} rets.append(boost_result) else: - raise RuntimeError(f'Type {type(arg)} is not supported') + raise RuntimeError(f"Type {type(arg)} is not supported") return rets[0] if len(rets) == 1 else rets @@ -125,11 +127,9 @@ def setup_sampler(self, dataset) -> DistributedSampler: return DistributedSampler(dataset, 1, 0) @abstractmethod - def save_pretrained(self, - model: nn.Module, - path: str, - only_rank0: bool = True, - tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + def save_pretrained( + self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None + ) -> None: pass @abstractmethod diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index fa55f97ad661..4706f9699c91 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -42,27 +42,27 @@ class LowLevelZeroStrategy(DDPStrategy): """ - def __init__(self, - stage: int = 2, - precision: str = 'fp16', - seed: int = 42, - placement_policy: str = 'cuda', - reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2 - overlap_communication: bool = True, # only for stage 1&2 - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0 - ) -> None: - + def __init__( + self, + stage: int = 2, + precision: str = "fp16", + seed: int = 42, + placement_policy: str = "cuda", + reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2 + overlap_communication: bool = True, # only for stage 1&2 + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + ) -> None: assert stage in (1, 2), f'Unsupported stage "{stage}"' - assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' - assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"' + assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"' + assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"' plugin_initializer = lambda: LowLevelZeroPlugin( # zero_config @@ -71,7 +71,7 @@ def __init__(self, # zero_optim_config reduce_bucket_size_in_m=reduce_bucket_size, overlap_communication=overlap_communication, - cpu_offload=(placement_policy == 'cpu'), + cpu_offload=(placement_policy == "cpu"), # optim_config initial_scale=initial_scale, growth_factor=growth_factor, @@ -81,14 +81,15 @@ def __init__(self, min_scale=min_scale, max_scale=max_scale, max_norm=max_norm, - norm_type=norm_type + norm_type=norm_type, ) super().__init__(seed, plugin_initializer) def _post_init(self) -> None: - assert isinstance(self.plugin, LowLevelZeroPlugin), \ - f'{type(self).__name__}\'s plugin is not initialized properly.' + assert isinstance( + self.plugin, LowLevelZeroPlugin + ), f"{type(self).__name__}'s plugin is not initialized properly." def setup_distributed(self) -> None: colossalai.launch_from_torch({}, seed=self.seed) @@ -131,45 +132,45 @@ class GeminiStrategy(DDPStrategy): """ - def __init__(self, - seed: int = 42, - shard_init: bool = False, # only for stage 3 - placement_policy: str = 'cuda', - pin_memory: bool = True, # only for stage 3 - force_outputs_fp32: bool = False, # only for stage 3 - search_range_m: int = 32, # only for stage 3 - hidden_dim: Optional[int] = None, # only for stage 3 - min_chunk_size_m: float = 32, # only for stage 3 - gpu_margin_mem_ratio: float = 0.0, # only for stage 3 - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0 - ) -> None: - - assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"' + def __init__( + self, + seed: int = 42, + shard_init: bool = False, # only for stage 3 + placement_policy: str = "cuda", + pin_memory: bool = True, # only for stage 3 + force_outputs_fp32: bool = False, # only for stage 3 + search_range_m: int = 32, # only for stage 3 + hidden_dim: Optional[int] = None, # only for stage 3 + min_chunk_size_m: float = 32, # only for stage 3 + gpu_margin_mem_ratio: float = 0.0, # only for stage 3 + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + ) -> None: + assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"' # TODO(ver217): support shard_init when using from_pretrained() if shard_init: warnings.warn( - f'Shard init is not supported model.from_pretrained() yet. ' - 'Please load weights after strategy.prepare()' + f"Shard init is not supported model.from_pretrained() yet. " + "Please load weights after strategy.prepare()" ) self.shard_init = shard_init - warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.') + warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.") # NOTE: dist should be initialized before calling get_current_device() plugin_initializer = lambda: GeminiPlugin( # gemini_config device=get_current_device(), placement_policy=placement_policy, - precision='fp16', + precision="fp16", pin_memory=pin_memory, force_outputs_fp32=force_outputs_fp32, strict_ddp_mode=shard_init, @@ -187,14 +188,13 @@ def __init__(self, min_scale=min_scale, max_scale=max_scale, max_norm=max_norm, - norm_type=norm_type + norm_type=norm_type, ) super().__init__(seed, plugin_initializer) def _post_init(self) -> None: - assert isinstance(self.plugin, GeminiPlugin), \ - f'{type(self).__name__}\'s plugin is not initialized properly.' + assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly." def setup_distributed(self) -> None: colossalai.launch_from_torch({}, seed=self.seed) @@ -203,10 +203,9 @@ def model_init_context(self): world_size = dist.get_world_size() shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None - return ColoInitContext(device=get_current_device(), - dtype=torch.half, - default_pg=shard_pg, - default_dist_spec=default_dist_spec) + return ColoInitContext( + device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec + ) def unwrap_model(self, model: nn.Module) -> nn.Module: assert isinstance(model, GeminiModel) diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py index a52b0460daa8..66ff6703da4d 100644 --- a/applications/Chat/coati/trainer/strategies/ddp.py +++ b/applications/Chat/coati/trainer/strategies/ddp.py @@ -31,24 +31,21 @@ def get_grad_required_state_dict(model: nn.Module): class DDPStrategy(Strategy): """ - Strategy for distributed training using torch.distributed. + Strategy for distributed training using torch.distributed. """ - def __init__(self, - seed: int = 42, - plugin_initializer: Callable = TorchDDPPlugin - ) -> None: + def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None: self.seed = seed super().__init__(plugin_initializer) def _try_init_dist(self, force: bool = False) -> None: try: - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - host = os.environ['MASTER_ADDR'] - port = int(os.environ['MASTER_PORT']) - dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank) + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + port = int(os.environ["MASTER_PORT"]) + dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank) torch.cuda.set_device(local_rank) except KeyError as e: if force: @@ -60,8 +57,7 @@ def _try_init_dist(self, force: bool = False) -> None: raise e def _post_init(self) -> None: - assert isinstance(self.plugin, TorchDDPPlugin), \ - f'{type(self).__name__}\'s plugin is not initialized properly.' + assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly." def setup_distributed(self) -> None: self._try_init_dist(force=True) @@ -73,12 +69,14 @@ def set_seed(self, seed: int) -> None: torch.manual_seed(seed) def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader: - return self.plugin.prepare_dataloader(data_buffer, - batch_size=data_buffer.sample_batch_size, - shuffle=True, - drop_last=True, - pin_memory=pin_memory, - collate_fn=data_buffer.collate_fn) + return self.plugin.prepare_dataloader( + data_buffer, + batch_size=data_buffer.sample_batch_size, + shuffle=True, + drop_last=True, + pin_memory=pin_memory, + collate_fn=data_buffer.collate_fn, + ) def setup_sampler(self, dataset) -> DistributedSampler: # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API. @@ -88,11 +86,9 @@ def unwrap_model(self, model: nn.Module) -> nn.Module: assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel." return model.unwrap() - def save_pretrained(self, - model: nn.Module, - path: str, - only_rank0: bool = True, - tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + def save_pretrained( + self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None + ) -> None: if not only_rank0 or dist.get_rank() == 0: unwrapped_model = self.unwrap_model(model) assert isinstance(unwrapped_model, (Actor, Critic, RewardModel)) @@ -103,17 +99,11 @@ def save_pretrained(self, if tokenizer is not None: tokenizer.save_pretrained(path) model_path = os.path.join(path, "pytorch_model.bin") - self.save_model(model, - model_path, - only_rank0=only_rank0) + self.save_model(model, model_path, only_rank0=only_rank0) - def _replace_keys(model_path: str, - replace_fn: Callable): + def _replace_keys(model_path: str, replace_fn: Callable): state_dict = torch.load(model_path, map_location="cpu") - state_dict = { - replace_fn(k): v - for k, v in state_dict.items() - } + state_dict = {replace_fn(k): v for k, v in state_dict.items()} torch.save(state_dict, model_path) # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin @@ -124,13 +114,13 @@ def _replace_keys(model_path: str, def get_model_state_dict_shard(self, model: nn.Module, **config): # TODO: implement sharding on naive strategy model = self.unwrap_model(model) - if 'requires_grad_only' in config and config['requires_grad_only'] == True: + if "requires_grad_only" in config and config["requires_grad_only"] == True: state_dict = get_grad_required_state_dict(model) else: state_dict = model.state_dict() - if 'shard_size' in config: - shard_size = config['shard_size'] + if "shard_size" in config: + shard_size = config["shard_size"] accumulate_size = 0 state_dict_shard = OrderedDict() for name, param in state_dict.items(): diff --git a/applications/Chat/coati/trainer/strategies/sampler.py b/applications/Chat/coati/trainer/strategies/sampler.py index d726fa640fa2..6e811bef11a5 100644 --- a/applications/Chat/coati/trainer/strategies/sampler.py +++ b/applications/Chat/coati/trainer/strategies/sampler.py @@ -4,7 +4,6 @@ class DistributedSampler: - def __init__(self, dataset, num_replicas: int, rank: int) -> None: self.dataset = dataset self.num_replicas = num_replicas @@ -12,7 +11,7 @@ def __init__(self, dataset, num_replicas: int, rank: int) -> None: if len(self.dataset) % self.num_replicas != 0: self.num_samples = math.ceil( - (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] ) else: self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) @@ -20,10 +19,10 @@ def __init__(self, dataset, num_replicas: int, rank: int) -> None: self.total_size = self.num_samples * self.num_replicas indices = list(range(len(self.dataset))) - indices = indices[:self.total_size] + indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples self.indices = indices diff --git a/applications/Chat/coati/trainer/utils.py b/applications/Chat/coati/trainer/utils.py index 7e2cb9c634f7..7811e7365eeb 100644 --- a/applications/Chat/coati/trainer/utils.py +++ b/applications/Chat/coati/trainer/utils.py @@ -42,7 +42,6 @@ def is_rank_0() -> bool: def to_device(x: Any, device: torch.device) -> Any: - def _to(t: Any): if isinstance(t, torch.Tensor): return t.to(device) diff --git a/applications/Chat/evaluate/config/config_cn.json b/applications/Chat/evaluate/config/config_cn.json index 023f16bef31c..4d30d005df30 100644 --- a/applications/Chat/evaluate/config/config_cn.json +++ b/applications/Chat/evaluate/config/config_cn.json @@ -70,7 +70,7 @@ "BLEU", "ROUGE", "BERTScore" - ] + ] }, "logical_reasoning": { "GPT": [ @@ -83,7 +83,7 @@ "ROUGE", "BERTScore", "CHRF" - ] + ] }, "open_qa": { "GPT": [ @@ -126,7 +126,7 @@ "conciseness" ], "Metrics": [ - ] + ] }, "Finance": { "GPT": [ @@ -134,7 +134,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "Law": { "GPT": [ @@ -142,7 +142,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "Education": { "GPT": [ @@ -150,7 +150,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "Medical": { "GPT": [ @@ -158,7 +158,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "STEM": { "GPT": [ @@ -166,7 +166,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "SocialScience": { "GPT": [ @@ -174,7 +174,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "Humanity": { "GPT": [ @@ -182,7 +182,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "Other": { "GPT": [ @@ -190,7 +190,7 @@ "correctness" ], "Metrics": [ - ] + ] }, "ethics": { "GPT": [ @@ -198,7 +198,7 @@ "correctness" ], "Metrics": [ - ] + ] } } } diff --git a/applications/Chat/evaluate/eval.py b/applications/Chat/evaluate/eval.py index e3fe0e9e091b..16ef31a94175 100644 --- a/applications/Chat/evaluate/eval.py +++ b/applications/Chat/evaluate/eval.py @@ -1,5 +1,4 @@ import argparse -import json import os import openai @@ -9,7 +8,8 @@ def main(args): assert len(args.answer_file_list) == len( - args.model_name_list), "The number of answer files and model names should be equal!" + args.model_name_list + ), "The number of answer files and model names should be equal!" # load config config = jload(args.config_file) @@ -36,7 +36,8 @@ def main(args): if len(args.model_name_list) == 1 and not gpt_evaluation_prompt: raise Exception( - "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!") + "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!" + ) if args.gpt_model == "text-davinci-003" and args.gpt_with_reference: raise Exception( @@ -44,8 +45,15 @@ def main(args): ) # initialize evaluator - evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt, args.gpt_model, - config["language"], config.get("path_for_UniEval", None), args.gpt_with_reference) + evaluator = Evaluator( + metrics_per_category, + battle_prompt, + gpt_evaluation_prompt, + args.gpt_model, + config["language"], + config.get("path_for_UniEval", None), + args.gpt_with_reference, + ) if len(args.model_name_list) == 2: answers1 = jload(args.answer_file_list[0]) answers2 = jload(args.answer_file_list[1]) @@ -68,41 +76,41 @@ def main(args): raise ValueError(f'Unsupported language {config["language"]}!') -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='ColossalAI LLM evaluation pipeline.') - parser.add_argument('--config_file', - type=str, - default=None, - required=True, - help='path to the file of target results') - parser.add_argument('--battle_prompt_file', type=str, default=None, help='path to the prompt file for battle') - parser.add_argument('--gpt_evaluation_prompt_file', - type=str, - default=None, - help='path to the prompt file for gpt evaluation') - parser.add_argument('--target_file', type=str, default=None, help='path to the target answer (ground truth) file') - parser.add_argument('--answer_file_list', - type=str, - nargs='+', - default=[], - required=True, - help='path to the answer files of at most 2 models') - parser.add_argument('--model_name_list', - type=str, - nargs='+', - default=[], - required=True, - help='the names of at most 2 models') - parser.add_argument('--gpt_model', - default="gpt-3.5-turbo", - choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"], - help='which GPT model to use for evaluation') - parser.add_argument('--gpt_with_reference', - default=False, - action="store_true", - help='whether to include reference answer in gpt evaluation') - parser.add_argument('--save_path', type=str, default="results", help='path to save evaluation results') - parser.add_argument('--openai_key', type=str, default=None, required=True, help='Your openai key') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ColossalAI LLM evaluation pipeline.") + parser.add_argument( + "--config_file", type=str, default=None, required=True, help="path to the file of target results" + ) + parser.add_argument("--battle_prompt_file", type=str, default=None, help="path to the prompt file for battle") + parser.add_argument( + "--gpt_evaluation_prompt_file", type=str, default=None, help="path to the prompt file for gpt evaluation" + ) + parser.add_argument("--target_file", type=str, default=None, help="path to the target answer (ground truth) file") + parser.add_argument( + "--answer_file_list", + type=str, + nargs="+", + default=[], + required=True, + help="path to the answer files of at most 2 models", + ) + parser.add_argument( + "--model_name_list", type=str, nargs="+", default=[], required=True, help="the names of at most 2 models" + ) + parser.add_argument( + "--gpt_model", + default="gpt-3.5-turbo", + choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"], + help="which GPT model to use for evaluation", + ) + parser.add_argument( + "--gpt_with_reference", + default=False, + action="store_true", + help="whether to include reference answer in gpt evaluation", + ) + parser.add_argument("--save_path", type=str, default="results", help="path to save evaluation results") + parser.add_argument("--openai_key", type=str, default=None, required=True, help="Your openai key") args = parser.parse_args() if args.openai_key is not None: diff --git a/applications/Chat/evaluate/evaluator.py b/applications/Chat/evaluate/evaluator.py index 3dd5fd6f2f23..1d998cd2d09c 100644 --- a/applications/Chat/evaluate/evaluator.py +++ b/applications/Chat/evaluate/evaluator.py @@ -3,20 +3,27 @@ import gpt_evaluate import metrics -import pandas as pd import unieval from utils import analyze_automatic_results, get_data_per_category, save_automatic_results class Evaluator(object): """ - A class named Evaluator includes GPT-3.5/GPT-4 evaluation - and automatic evaluation + A class named Evaluator includes GPT-3.5/GPT-4 evaluation + and automatic evaluation """ - def __init__(self, params: Dict[str, Any], battle_prompt: Dict[str, Any], gpt_evaluation_prompt: Dict[str, Any], - gpt_model: str, language: str, path_for_UniEval: Dict[str, str], gpt_with_reference: bool) -> None: + def __init__( + self, + params: Dict[str, Any], + battle_prompt: Dict[str, Any], + gpt_evaluation_prompt: Dict[str, Any], + gpt_model: str, + language: str, + path_for_UniEval: Dict[str, str], + gpt_with_reference: bool, + ) -> None: self.params = params self.battle_prompt = battle_prompt self.gpt_evaluation_prompt = gpt_evaluation_prompt @@ -103,7 +110,8 @@ def switch(metric, language): if self.params[category]["UniEval"] and self.language == "cn": raise Exception( - "UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file.") + "UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file." + ) category_metrics = self.params[category]["UniEval"] @@ -134,10 +142,9 @@ def switch(metric, language): sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]] data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list) - scores = uni_evaluator.evaluate(data, - category, - dims=list(self.unieval_metric_stats[task][category].keys()), - overall=False) + scores = uni_evaluator.evaluate( + data, category, dims=list(self.unieval_metric_stats[task][category].keys()), overall=False + ) avg_scores = unieval.calculate_average_score(scores) self.unieval_metric_stats[task][category].update(avg_scores) @@ -165,7 +172,8 @@ def switch(metric, language): category, self.gpt_model, self.language, - references=targets_per_category[category] if self.gpt_with_reference else None) + references=targets_per_category[category] if self.gpt_with_reference else None, + ) def save(self, path: str, model_name_list: List[str]) -> None: """ @@ -204,16 +212,18 @@ def save(self, path: str, model_name_list: List[str]) -> None: gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results") gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results") - all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0], - self.gpt_evaluation_results, - gpt_evaluation_results_save_path) + all_evaluations = gpt_evaluate.save_gpt_evaluation_results( + model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path + ) # Start to calculate scores and save statistics. gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics") - gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations, - gpt_evaluation_statistics_save_path) + gpt_evaluate.save_gpt_evaluation_statistics( + model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path + ) # Save charts and csv. gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses") - gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path, - gpt_evaluation_analyses_save_path) + gpt_evaluate.analyze_gpt_evaluation_statistics( + gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path + ) diff --git a/applications/Chat/evaluate/gpt_evaluate.py b/applications/Chat/evaluate/gpt_evaluate.py index 6fcbe63d0253..ad908f4ba48c 100644 --- a/applications/Chat/evaluate/gpt_evaluate.py +++ b/applications/Chat/evaluate/gpt_evaluate.py @@ -14,20 +14,18 @@ from utils import jdump, jload ref_step_template = { - "en": - "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n", - "cn": - "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n" + "en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n", + "cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n", } ref_answer_template_general = { "en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n", - "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n" + "cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n", } ref_answer_template_correctness = { "en": "\nA correct answer is as follows:\n\n{answer}\n\n", - "cn": "\n标准答案如下:\n\n{answer}\n\n" + "cn": "\n标准答案如下:\n\n{answer}\n\n", } @@ -51,10 +49,7 @@ def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: in response = openai.ChatCompletion.create( model="gpt-4", messages=[ - { - "role": "system", - "content": sys_prompt - }, + {"role": "system", "content": sys_prompt}, { "role": "user", "content": user_prompt, @@ -106,7 +101,7 @@ def parse_battle_score(evaluation: str) -> List[float]: return [float(sp[0]), float(sp[1])] else: raise Exception(f"Invalid score pair. Got {evaluation}.") - except Exception as e: + except Exception: return [-1, -1] @@ -125,9 +120,6 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any] assert len(answer1) == len(answer2) - handles = [] - evaluation_file = [] - total_len = len(answer1) question_idx_list = list(range(total_len)) @@ -140,9 +132,12 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any] assert answer1[i]["id"] == answer2[i]["id"] answer_id = answer1[i]["id"] - ques = answer1[i]["instruction"] if answer1[i][ - "input"] == "" else answer1[i]["instruction"] + " " + answer1[i]["input"] - cat = answer1[i]["category"] + ques = ( + answer1[i]["instruction"] + if answer1[i]["input"] == "" + else answer1[i]["instruction"] + " " + answer1[i]["input"] + ) + answer1[i]["category"] ans1 = answer1[i]["output"] ans2 = answer2[i]["output"] @@ -267,7 +262,11 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) -> step_to_add = ref_step_template[language] - for_the_given_answer = "{metric} (1-5) (directly give the score for the given answer):" if language == "en" else "{metric} (1-5) (直接对给定答案打分)" + for_the_given_answer = ( + "{metric} (1-5) (directly give the score for the given answer):" + if language == "en" + else "{metric} (1-5) (直接对给定答案打分)" + ) # adjective is used to describe the word "answer" in the prompt. adjective = "example" if language == "en" else "示例" @@ -280,8 +279,9 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) -> answer_to_add = ref_answer_template_correctness[language] answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"]) - step_to_add = step_to_add.format(metric=metric.lower(), - adjective=adjective) + for_the_given_answer.format(metric=metric) + step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format( + metric=metric + ) return answer_to_add + step_to_add @@ -329,7 +329,8 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens: for j in range(i): messages_to_send.append(fill_in_message("user", user_messages[j])) messages_to_send.append( - fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])) + fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"]) + ) # Length of user messages == Length of assistant messages + 1 # Because we always expect the api to response @@ -351,13 +352,15 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens: return assistant_responses[-1] -def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], - inst: Dict[str, Any], - metrics: List[str], - language: str, - reference: Dict[str, Any] = None, - model: str = "gpt-3.5-turbo", - max_tokens: int = 2048) -> Dict[str, Any]: +def get_gpt_evaluation_without_logprobs( + prompt: Dict[str, Any], + inst: Dict[str, Any], + metrics: List[str], + language: str, + reference: Dict[str, Any] = None, + model: str = "gpt-3.5-turbo", + max_tokens: int = 2048, +) -> Dict[str, Any]: """ Use chat models(gpt-3.5-turbo or gpt-4) to evaluate one model answer. @@ -378,7 +381,7 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], MAX_API_RETRY = 3 - question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]) + question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"] answer = inst["output"] inst["evaluation"] = {} @@ -400,10 +403,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], if prompt_reference: # Do a 2-round conversation - response = multiturn_chat_completion([prompt_1st_round, prompt_reference], - model, - max_tokens=max_tokens, - turns=2) + response = multiturn_chat_completion( + [prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2 + ) else: response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1) @@ -427,10 +429,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any], return inst -def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any], - inst: Dict[str, Any], - metrics: List[str], - max_tokens: int = 2048) -> Dict[str, Any]: +def get_gpt_evaluation_with_logprobs( + prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048 +) -> Dict[str, Any]: """ Use completion model(text-davinci-003) to evaluate one model answer. Only completion models can return log probabilities. @@ -449,7 +450,7 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any], MAX_API_RETRY = 3 - question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]) + question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"] answer = inst["output"] inst["evaluation"] = {} @@ -492,13 +493,15 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any], return inst -def evaluate(answers: List[Dict], - prompt: Dict[str, Any], - metrics: List[str], - category: str, - model: str, - language: str, - references: List[Dict] = None) -> List[Dict]: +def evaluate( + answers: List[Dict], + prompt: Dict[str, Any], + metrics: List[str], + category: str, + model: str, + language: str, + references: List[Dict] = None, +) -> List[Dict]: """ Use GPT models to evaluate model answers and save evaluation results. @@ -529,21 +532,23 @@ def evaluate(answers: List[Dict], if model == "text-davinci-003": future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1) else: - future = executor.submit(get_gpt_evaluation_without_logprobs, - prompt, - inst, - metrics, - language, - reference=None if references is None else references[idx], - model=model, - max_tokens=1) + future = executor.submit( + get_gpt_evaluation_without_logprobs, + prompt, + inst, + metrics, + language, + reference=None if references is None else references[idx], + model=model, + max_tokens=1, + ) futures.append(future) for future in tqdm.tqdm( - concurrent.futures.as_completed(futures), - desc=f"{category}: ", - total=len(futures), + concurrent.futures.as_completed(futures), + desc=f"{category}: ", + total=len(futures), ): evaluations.append(future.result()) @@ -610,12 +615,13 @@ def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) -> return int(results[0]) else: raise Exception(f"Invalid score pair. Got {evaluation}.") - except Exception as e: + except Exception: return 0 -def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[str, Any], - save_path: str) -> Dict[str, Any]: +def save_gpt_evaluation_results( + model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str +) -> Dict[str, Any]: """ Save evaluation results for different categories for one model. @@ -667,10 +673,12 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav scores[metric].append(0) elif evaluation["evaluation"][metric]["logprobs"] is not None: scores[metric].append( - calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0])) + calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0]) + ) else: scores[metric].append( - calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation)) + calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation) + ) statistics = {} for metric in metrics: @@ -751,9 +759,9 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv")) for category in tqdm.tqdm( - frame_per_category.keys(), - desc=f"GPT evaluation: ", - total=len(frame_per_category.keys()), + frame_per_category.keys(), + desc=f"GPT evaluation: ", + total=len(frame_per_category.keys()), ): data = pd.DataFrame(frame_per_category[category]) diff --git a/applications/Chat/evaluate/metrics.py b/applications/Chat/evaluate/metrics.py index 77f9b6e98044..85ee4de53725 100644 --- a/applications/Chat/evaluate/metrics.py +++ b/applications/Chat/evaluate/metrics.py @@ -21,13 +21,17 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str, """ bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0} cumulative_bleu = [0] * 4 - weights = [(1. / 1., 0., 0., 0.), (1. / 2., 1. / 2., 0., 0.), (1. / 3., 1. / 3., 1. / 3., 0.), - (1. / 4., 1. / 4., 1. / 4., 1. / 4.)] + weights = [ + (1.0 / 1.0, 0.0, 0.0, 0.0), + (1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0), + (1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0), + (1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0), + ] for pred, target in zip(preds, targets): if language == "cn": - pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split() - target_list = [(' '.join(jieba.cut(preprocessing_text(target)))).split()] + pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split() + target_list = [(" ".join(jieba.cut(preprocessing_text(target)))).split()] elif language == "en": pred_list = preprocessing_text(pred).split() target_list = [preprocessing_text(target).split()] @@ -42,15 +46,14 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str, def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]: - """Calculate CHRF Score Metric in sentence level. - """ + """Calculate CHRF Score Metric in sentence level.""" chrf_score = {"chrf": 0} cumulative_chrf = [] for pred, target in zip(preds, targets): if language == "cn": - pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split() - target_list = ' '.join(jieba.cut(preprocessing_text(target))).split() + pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split() + target_list = " ".join(jieba.cut(preprocessing_text(target))).split() elif language == "en": pred_list = preprocessing_text(pred).split() target_list = preprocessing_text(target).split() @@ -75,8 +78,8 @@ def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]: all_targets = [] for pred, target in zip(preds, targets): - pred_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(pred)))) - target_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(target)))) + pred_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(pred)))) + target_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(target)))) all_preds.append(pred_list) all_targets.append(target_list) @@ -99,16 +102,14 @@ def rouge_en_score(preds: List[str], targets: List[str]) -> Dict[str, float]: longest common subsequence (LCS) between preds and targets. """ rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0} - all_preds = [] - all_targets = [] rouge_en = Rouge_en.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False) for pred, target in zip(preds, targets): score = rouge_en.score(preprocessing_text(pred), preprocessing_text(target)) - rouge_scores["rouge1"] += score['rouge1'].fmeasure - rouge_scores["rouge2"] += score['rouge2'].fmeasure - rouge_scores["rougeL"] += score['rougeL'].fmeasure + rouge_scores["rouge1"] += score["rouge1"].fmeasure + rouge_scores["rouge2"] += score["rouge2"].fmeasure + rouge_scores["rougeL"] += score["rougeL"].fmeasure rouge_scores["rouge1"] = rouge_scores["rouge1"] / len(preds) rouge_scores["rouge2"] = rouge_scores["rouge2"] / len(preds) @@ -137,7 +138,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]: for pred in preds: if language == "cn": - pred_seg_list = ' '.join(jieba.cut(pred)).split() + pred_seg_list = " ".join(jieba.cut(pred)).split() count_segs = len(pred_seg_list) unique_segs = set(pred_seg_list) count_unique_chars = len(unique_segs) @@ -151,7 +152,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]: split_pred = preprocessing_text(pred).split() for n in range(0, 3): for i in range(0, len(split_pred) - n): - ngram = ' '.join(split_pred[i:i + n + 1]) + ngram = " ".join(split_pred[i : i + n + 1]) unique_ngram[n].add(ngram) all_ngram_count[n] += 1 @@ -203,8 +204,8 @@ def calculate_precision_recall_f1(preds: List[str], targets: List[str], language for pred, target in zip(preds, targets): if language == "cn": - pred_list = [char for char in ' '.join(jieba.cut(preprocessing_text(pred))).split()] - target_list = [char for char in ' '.join(jieba.cut(preprocessing_text(target))).split()] + pred_list = [char for char in " ".join(jieba.cut(preprocessing_text(pred))).split()] + target_list = [char for char in " ".join(jieba.cut(preprocessing_text(target))).split()] elif language == "en": pred_list = [char for char in preprocessing_text(pred).split()] target_list = [char for char in preprocessing_text(target).split()] diff --git a/applications/Chat/evaluate/unieval/__init__.py b/applications/Chat/evaluate/unieval/__init__.py index dad8d6ad09fa..6ffccdaa0819 100644 --- a/applications/Chat/evaluate/unieval/__init__.py +++ b/applications/Chat/evaluate/unieval/__init__.py @@ -7,6 +7,9 @@ ) __all__ = [ - 'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results', - 'analyze_unieval_results' + "get_evaluator", + "convert_data_to_unieval_format", + "calculate_average_score", + "save_unieval_results", + "analyze_unieval_results", ] diff --git a/applications/Chat/evaluate/unieval/evaluator.py b/applications/Chat/evaluate/unieval/evaluator.py index 56cc6d2f9e41..bf2bc33a95c0 100644 --- a/applications/Chat/evaluate/unieval/evaluator.py +++ b/applications/Chat/evaluate/unieval/evaluator.py @@ -28,29 +28,29 @@ class SumEvaluator: - - def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): - """ Set up evaluator for text summarization """ + def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): + """Set up evaluator for text summarization""" self.scorer = UniEvaluator( - model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path, + model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path, max_length=max_length, device=device, - cache_dir=cache_dir) - self.task = 'summarization' - self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance'] + cache_dir=cache_dir, + ) + self.task = "summarization" + self.dimensions = ["coherence", "consistency", "fluency", "relevance"] def evaluate(self, data, category, dims=None, overall=True): """ - Get the scores of all the given dimensions + Get the scores of all the given dimensions - category: The category to be evaluated. + category: The category to be evaluated. - dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate - four dimensions: coherence, consistency, fluency, relevance. + dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate + four dimensions: coherence, consistency, fluency, relevance. - overall: indicates whether the overall score is to be calculated. - Overall score can be customized to a combination of scores based on different - dimensions. The default here is the average score of all the given dimensions. + overall: indicates whether the overall score is to be calculated. + Overall score can be customized to a combination of scores based on different + dimensions. The default here is the average score of all the given dimensions. """ n_data = len(data) eval_scores = [{} for _ in range(n_data)] @@ -63,12 +63,12 @@ def evaluate(self, data, category, dims=None, overall=True): for dim in eval_dims: # Calculate average sentence-level scores for 'consistency' and 'fluency' - if dim == 'consistency' or dim == 'fluency': + if dim == "consistency" or dim == "fluency": src_list, output_list = [], [] - n_sents = [] # the number of sentences in each generated summary + n_sents = [] # the number of sentences in each generated summary for i in range(n_data): - source = data[i]['source'] - system_outputs = sent_tokenize(data[i]['system_output']) + source = data[i]["source"] + system_outputs = sent_tokenize(data[i]["system_output"]) n_sents.append(len(system_outputs)) for j in range(len(system_outputs)): src_list.append(source) @@ -81,24 +81,26 @@ def evaluate(self, data, category, dims=None, overall=True): score = [] for cur_n_sent in n_sents: # prevent denominator from being 0 - score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / (cur_n_sent + 1e-6)) + score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / (cur_n_sent + 1e-6)) start_idx += cur_n_sent # Calculate summary-level score for 'coherence' and 'relevance' - elif dim == 'coherence' or dim == 'relevance': + elif dim == "coherence" or dim == "relevance": src_list, output_list, ref_list = [], [], [] for i in range(n_data): - src_list.append(data[i]['source']) - output_list.append(data[i]['system_output']) - if dim == 'relevance': - ref_list.append(data[i]['reference']) + src_list.append(data[i]["source"]) + output_list.append(data[i]["system_output"]) + if dim == "relevance": + ref_list.append(data[i]["reference"]) input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task) score = self.scorer.score(input_list, self.task, category, dim) # Please customize other dimensions here for summarization else: - raise NotImplementedError('The input format for this dimension is still undefined. \ - Please customize it first.') + raise NotImplementedError( + "The input format for this dimension is still undefined. \ + Please customize it first." + ) for i in range(n_data): eval_scores[i][dim] = score[i] @@ -106,35 +108,35 @@ def evaluate(self, data, category, dims=None, overall=True): # Customize your overall score here. if overall == True: for i in range(n_data): - eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) + eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values())) return eval_scores class DialogEvaluator: - - def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): - """ Set up evaluator for dialogues """ + def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): + """Set up evaluator for dialogues""" self.scorer = UniEvaluator( - model_name_or_path='MingZhong/unieval-dialog' if model_name_or_path == "" else model_name_or_path, + model_name_or_path="MingZhong/unieval-dialog" if model_name_or_path == "" else model_name_or_path, max_length=max_length, device=device, - cache_dir=cache_dir) - self.task = 'dialogue' - self.dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability'] + cache_dir=cache_dir, + ) + self.task = "dialogue" + self.dimensions = ["naturalness", "coherence", "engagingness", "groundedness", "understandability"] def evaluate(self, data, category, dims=None, overall=True): """ - Get the scores of all the given dimensions + Get the scores of all the given dimensions - category: The category to be evaluated. + category: The category to be evaluated. - dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate - five dimensions: naturalness, coherence, engagingness, groundedness and understandability. + dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate + five dimensions: naturalness, coherence, engagingness, groundedness and understandability. - overall: indicates whether the overall score is to be calculated. - Overall score can be customized to a combination of scores based on different - dimensions. The default here is the average score of all the given dimensions. + overall: indicates whether the overall score is to be calculated. + Overall score can be customized to a combination of scores based on different + dimensions. The default here is the average score of all the given dimensions. """ n_data = len(data) eval_scores = [{} for _ in range(n_data)] @@ -147,50 +149,48 @@ def evaluate(self, data, category, dims=None, overall=True): for dim in eval_dims: # Calculate summation score for 'engagingness' - if dim == 'engagingness': + if dim == "engagingness": src_list, output_list, context_list = [], [], [] - n_sents = [] # the number of sentences in each generated response + n_sents = [] # the number of sentences in each generated response for i in range(n_data): - source = data[i]['source'] - context = data[i]['context'] - system_outputs = sent_tokenize(data[i]['system_output']) + source = data[i]["source"] + context = data[i]["context"] + system_outputs = sent_tokenize(data[i]["system_output"]) n_sents.append(len(system_outputs)) for j in range(len(system_outputs)): src_list.append(source) context_list.append(context) output_list.append(system_outputs[j]) - input_list = add_question(dimension=dim, - output=output_list, - src=src_list, - context=context_list, - task=self.task) + input_list = add_question( + dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task + ) sent_score = self.scorer.score(input_list, self.task, category, dim) # Get the summation score for each sample start_idx = 0 score = [] for cur_n_sent in n_sents: - score.append(sum(sent_score[start_idx:start_idx + cur_n_sent])) + score.append(sum(sent_score[start_idx : start_idx + cur_n_sent])) start_idx += cur_n_sent # Calculate turn-level score for other dimensions - elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']: + elif dim in ["naturalness", "coherence", "groundedness", "understandability"]: src_list, output_list, context_list = [], [], [] for i in range(n_data): - src_list.append(data[i]['source']) - output_list.append(data[i]['system_output']) - context_list.append(data[i]['context']) - input_list = add_question(dimension=dim, - output=output_list, - src=src_list, - context=context_list, - task=self.task) + src_list.append(data[i]["source"]) + output_list.append(data[i]["system_output"]) + context_list.append(data[i]["context"]) + input_list = add_question( + dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task + ) score = self.scorer.score(input_list, self.task, category, dim) # Please customize other dimensions here for summarization else: - raise NotImplementedError('The input format for this dimension is still undefined. \ - Please customize it first.') + raise NotImplementedError( + "The input format for this dimension is still undefined. \ + Please customize it first." + ) for i in range(n_data): eval_scores[i][dim] = score[i] @@ -198,35 +198,35 @@ def evaluate(self, data, category, dims=None, overall=True): # Customize your overall score here. if overall == True: for i in range(n_data): - eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) + eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values())) return eval_scores class D2tEvaluator: - - def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): - """ Set up evaluator for data-to-text """ + def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): + """Set up evaluator for data-to-text""" self.scorer = UniEvaluator( - model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path, + model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path, max_length=max_length, device=device, - cache_dir=cache_dir) - self.task = 'data2text' - self.dimensions = ['naturalness', 'informativeness'] + cache_dir=cache_dir, + ) + self.task = "data2text" + self.dimensions = ["naturalness", "informativeness"] def evaluate(self, data, category, dims=None, overall=True): """ - Get the scores of all the given dimensions + Get the scores of all the given dimensions - category: The category to be evaluated. + category: The category to be evaluated. - dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate - two dimensions: naturalness and informativeness. + dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate + two dimensions: naturalness and informativeness. - overall: indicates whether the overall score is to be calculated. - Overall score can be customized to a combination of scores based on different - dimensions. The default here is the average score of all the given dimensions. + overall: indicates whether the overall score is to be calculated. + Overall score can be customized to a combination of scores based on different + dimensions. The default here is the average score of all the given dimensions. """ n_data = len(data) eval_scores = [{} for _ in range(n_data)] @@ -240,8 +240,8 @@ def evaluate(self, data, category, dims=None, overall=True): for dim in eval_dims: output_list, ref_list = [], [] for i in range(n_data): - output_list.append(data[i]['system_output']) - ref_list.append(data[i]['reference']) + output_list.append(data[i]["system_output"]) + ref_list.append(data[i]["reference"]) input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task) score = self.scorer.score(input_list, self.task, category, dim) @@ -252,38 +252,38 @@ def evaluate(self, data, category, dims=None, overall=True): # Customize your overall score here. if overall == True: for i in range(n_data): - eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values())) + eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values())) return eval_scores class FactEvaluator: - - def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): - """ Set up evaluator for factual consistency detection """ + def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): + """Set up evaluator for factual consistency detection""" self.scorer = UniEvaluator( - model_name_or_path='MingZhong/unieval-fact' if model_name_or_path == "" else model_name_or_path, + model_name_or_path="MingZhong/unieval-fact" if model_name_or_path == "" else model_name_or_path, max_length=max_length, device=device, - cache_dir=cache_dir) - self.task = 'fact' - self.dim = 'consistency' + cache_dir=cache_dir, + ) + self.task = "fact" + self.dim = "consistency" def evaluate(self, data, category): """ - Get the factual consistency score (only 1 dimension for this task) + Get the factual consistency score (only 1 dimension for this task) - category: The category to be evaluated. + category: The category to be evaluated. """ n_data = len(data) eval_scores = [{} for _ in range(n_data)] # Calculate average sentence-level scores for factual consistency src_list, output_list = [], [] - n_sents = [] # the number of sentences in the claim + n_sents = [] # the number of sentences in the claim for i in range(n_data): - source = data[i]['source'] - system_outputs = sent_tokenize(data[i]['system_output']) + source = data[i]["source"] + system_outputs = sent_tokenize(data[i]["system_output"]) n_sents.append(len(system_outputs)) for j in range(len(system_outputs)): src_list.append(source) @@ -295,7 +295,7 @@ def evaluate(self, data, category): start_idx = 0 score = [] for cur_n_sent in n_sents: - score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent) + score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / cur_n_sent) start_idx += cur_n_sent for i in range(n_data): @@ -304,28 +304,26 @@ def evaluate(self, data, category): return eval_scores -def get_evaluator(task, model_name_or_path="", max_length=1024, device='cuda:0', cache_dir=None): - assert task in ['summarization', 'dialogue', 'data2text', 'fact'] - if task == 'summarization': - return SumEvaluator(model_name_or_path=model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir) - elif task == 'dialogue': - return DialogEvaluator(model_name_or_path=model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir) - elif task == 'data2text': - return D2tEvaluator(model_name_or_path=model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir) - elif task == 'fact': - return FactEvaluator(model_name_or_path=model_name_or_path, - max_length=max_length, - device=device, - cache_dir=cache_dir) +def get_evaluator(task, model_name_or_path="", max_length=1024, device="cuda:0", cache_dir=None): + assert task in ["summarization", "dialogue", "data2text", "fact"] + if task == "summarization": + return SumEvaluator( + model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir + ) + elif task == "dialogue": + return DialogEvaluator( + model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir + ) + elif task == "data2text": + return D2tEvaluator( + model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir + ) + elif task == "fact": + return FactEvaluator( + model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir + ) else: - raise NotImplementedError('Other tasks are not implemented, \ - please customize specific tasks here.') + raise NotImplementedError( + "Other tasks are not implemented, \ + please customize specific tasks here." + ) diff --git a/applications/Chat/evaluate/unieval/scorer.py b/applications/Chat/evaluate/unieval/scorer.py index 2c70bb9f6ded..45706b833205 100644 --- a/applications/Chat/evaluate/unieval/scorer.py +++ b/applications/Chat/evaluate/unieval/scorer.py @@ -27,9 +27,8 @@ class UniEvaluator: - - def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None): - """ Set up model """ + def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None): + """Set up model""" self.device = device self.max_length = max_length @@ -47,8 +46,8 @@ def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_d def score(self, inputs, task, category, dim, batch_size=8): """ - Get scores for the given samples. - final_score = postive_score / (postive_score + negative_score) + Get scores for the given samples. + final_score = postive_score / (postive_score + negative_score) """ # The implementation of "forward" in T5 still requires decoder_input_ids. @@ -58,31 +57,27 @@ def score(self, inputs, task, category, dim, batch_size=8): pos_score_list, neg_score_list = [], [] for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "): - src_list = inputs[i:i + batch_size] - tgt_list = tgts[i:i + batch_size] + src_list = inputs[i : i + batch_size] + tgt_list = tgts[i : i + batch_size] try: with torch.no_grad(): - encoded_src = self.tokenizer(src_list, - max_length=self.max_length, - truncation=True, - padding=True, - return_tensors='pt') - encoded_tgt = self.tokenizer(tgt_list, - max_length=self.max_length, - truncation=True, - padding=True, - return_tensors='pt') - - src_tokens = encoded_src['input_ids'].to(self.device) - src_mask = encoded_src['attention_mask'].to(self.device) - - tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1) + encoded_src = self.tokenizer( + src_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + ) + encoded_tgt = self.tokenizer( + tgt_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + ) + + src_tokens = encoded_src["input_ids"].to(self.device) + src_mask = encoded_src["attention_mask"].to(self.device) + + tgt_tokens = encoded_tgt["input_ids"].to(self.device)[:, 0].unsqueeze(-1) output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens) logits = output.logits.view(-1, self.model.config.vocab_size) - pos_score = self.softmax(logits)[:, self.pos_id] # Yes - neg_score = self.softmax(logits)[:, self.neg_id] # No + pos_score = self.softmax(logits)[:, self.pos_id] # Yes + neg_score = self.softmax(logits)[:, self.neg_id] # No cur_pos_score = [x.item() for x in pos_score] cur_neg_score = [x.item() for x in neg_score] @@ -90,8 +85,8 @@ def score(self, inputs, task, category, dim, batch_size=8): neg_score_list += cur_neg_score except RuntimeError: - print(f'source: {src_list}') - print(f'target: {tgt_list}') + print(f"source: {src_list}") + print(f"target: {tgt_list}") exit(0) score_list = [] diff --git a/applications/Chat/evaluate/unieval/utils.py b/applications/Chat/evaluate/unieval/utils.py index a381e9e590b2..46b0f2907a30 100644 --- a/applications/Chat/evaluate/unieval/utils.py +++ b/applications/Chat/evaluate/unieval/utils.py @@ -31,105 +31,142 @@ def add_question(dimension, output, src=None, ref=None, context=None, task=None): """ - Add questions to generate input in Bool-QA format for UniEval. - - dimension: specific dimension to be evaluated - src: source input for different NLG tasks. For example, source document for summarization - and dialogue history for dialogue response generation. - output: output text generated by the models - ref: human-annotated groundtruth - context: the context needed to evaluate several specific dimension. For example, - additional factual information when evaluating engagingness and groundedness in dialogues. + Add questions to generate input in Bool-QA format for UniEval. + + dimension: specific dimension to be evaluated + src: source input for different NLG tasks. For example, source document for summarization + and dialogue history for dialogue response generation. + output: output text generated by the models + ref: human-annotated groundtruth + context: the context needed to evaluate several specific dimension. For example, + additional factual information when evaluating engagingness and groundedness in dialogues. """ input_with_question = [] for i in range(len(output)): # For summarization - if task == 'summarization': - if dimension == 'fluency': - cur_input = 'question: Is this a fluent paragraph? paragraph: ' + output[i] - elif dimension == 'coherence': - cur_input = 'question: Is this a coherent summary to the document? summary: ' + output[ - i] + ' document: ' + src[i] - elif dimension == 'consistency': - cur_input = 'question: Is this claim consistent with the document? claim: ' + output[ - i] + ' document: ' + src[i] - elif dimension == 'relevance': - cur_input = 'question: Is this summary relevant to the reference? summary: ' + output[ - i] + ' reference: ' + ref[i] + if task == "summarization": + if dimension == "fluency": + cur_input = "question: Is this a fluent paragraph? paragraph: " + output[i] + elif dimension == "coherence": + cur_input = ( + "question: Is this a coherent summary to the document? summary: " + + output[i] + + " document: " + + src[i] + ) + elif dimension == "consistency": + cur_input = ( + "question: Is this claim consistent with the document? claim: " + + output[i] + + " document: " + + src[i] + ) + elif dimension == "relevance": + cur_input = ( + "question: Is this summary relevant to the reference? summary: " + + output[i] + + " reference: " + + ref[i] + ) else: raise NotImplementedError( - 'The input format for this dimension is still undefined. Please customize it first.') + "The input format for this dimension is still undefined. Please customize it first." + ) # For dialogues - elif task == 'dialogue': - if dimension == 'naturalness': - cur_input = 'question: Is this a natural response in the dialogue? response: ' + output[i] - elif dimension == 'coherence': - cur_input = 'question: Is this a coherent response given the dialogue history? response: '\ - + output[i] + ' dialogue history: ' + src[i] - elif dimension == 'engagingness': - cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? response: '\ - + output[i] + ' dialogue history: ' + src[i] + ' fact: ' + context[i] - elif dimension == 'groundedness': - cur_input = 'question: Is this response consistent with knowledge in the fact? response: '\ - + output[i] + ' fact: ' + context[i] - elif dimension == 'understandability': - cur_input = 'question: Is this an understandable response in the dialogue? response: ' + output[i] + elif task == "dialogue": + if dimension == "naturalness": + cur_input = "question: Is this a natural response in the dialogue? response: " + output[i] + elif dimension == "coherence": + cur_input = ( + "question: Is this a coherent response given the dialogue history? response: " + + output[i] + + " dialogue history: " + + src[i] + ) + elif dimension == "engagingness": + cur_input = ( + "question: Is this an engaging and informative response according to the dialogue history and fact? response: " + + output[i] + + " dialogue history: " + + src[i] + + " fact: " + + context[i] + ) + elif dimension == "groundedness": + cur_input = ( + "question: Is this response consistent with knowledge in the fact? response: " + + output[i] + + " fact: " + + context[i] + ) + elif dimension == "understandability": + cur_input = "question: Is this an understandable response in the dialogue? response: " + output[i] else: raise NotImplementedError( - 'The input format for this dimension is still undefined. Please customize it first.') + "The input format for this dimension is still undefined. Please customize it first." + ) # For data-to-text - elif task == 'data2text': - if dimension == 'naturalness': - cur_input = 'question: Is this a fluent utterance? utterance: ' + output[i] - elif dimension == 'informativeness': - cur_input = 'question: Is this sentence informative according to the reference? sentence: '\ - + output[i] + ' reference: ' + ref[i] + elif task == "data2text": + if dimension == "naturalness": + cur_input = "question: Is this a fluent utterance? utterance: " + output[i] + elif dimension == "informativeness": + cur_input = ( + "question: Is this sentence informative according to the reference? sentence: " + + output[i] + + " reference: " + + ref[i] + ) else: raise NotImplementedError( - 'The input format for this dimension is still undefined. Please customize it first.') + "The input format for this dimension is still undefined. Please customize it first." + ) # For factual consistency detection - elif task == 'fact': - if dimension == 'consistency': - cur_input = 'question: Is this claim consistent with the document? claim: ' + output[ - i] + ' document: ' + src[i] + elif task == "fact": + if dimension == "consistency": + cur_input = ( + "question: Is this claim consistent with the document? claim: " + + output[i] + + " document: " + + src[i] + ) else: - raise NotImplementedError('No other dimensions for the factual consistency detection task.') + raise NotImplementedError("No other dimensions for the factual consistency detection task.") # For new customized tasks else: - raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.') + raise NotImplementedError("Other tasks are not implemented, please customize specific tasks here.") input_with_question.append(cur_input) return input_with_question def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None): """ - Convert the data into the unieval's format. + Convert the data into the unieval's format. - output_list: a list of model output + output_list: a list of model output - src_list: source input for different NLG tasks. For example, source document for summarization - and dialogue history for dialogue response generation - ref_list: human-annotated groundtruth + src_list: source input for different NLG tasks. For example, source document for summarization + and dialogue history for dialogue response generation + ref_list: human-annotated groundtruth """ json_data = [] for i in range(len(output_list)): cur = {} - cur['system_output'] = output_list[i] + cur["system_output"] = output_list[i] if src_list is not None: - cur['source'] = src_list[i] + cur["source"] = src_list[i] if ref_list is not None: - cur['reference'] = ref_list[i] - cur['context'] = "" + cur["reference"] = ref_list[i] + cur["context"] = "" json_data.append(cur) return json_data def calculate_average_score(scores): """ - Calculate average scores for different metrics + Calculate average scores for different metrics - scores: a list of scores for different metrics for each answer + scores: a list of scores for different metrics for each answer """ metrics = {metric: 0 for metric in scores[0]} @@ -226,9 +263,9 @@ def analyze_unieval_results(results_path: str, save_path: str) -> None: frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv")) for metric in tqdm.tqdm( - frame_per_metric.keys(), - desc=f"UniEval metrics: ", - total=len(frame_per_metric.keys()), + frame_per_metric.keys(), + desc=f"UniEval metrics: ", + total=len(frame_per_metric.keys()), ): data = pd.DataFrame(frame_per_metric[metric]) diff --git a/applications/Chat/evaluate/utils.py b/applications/Chat/evaluate/utils.py index 406e43db99aa..10df455b69d7 100644 --- a/applications/Chat/evaluate/utils.py +++ b/applications/Chat/evaluate/utils.py @@ -1,7 +1,6 @@ import io import json import os -import re import string from typing import Dict @@ -55,7 +54,7 @@ def jload(f, mode="r"): def get_json_list(file_path): - with open(file_path, 'r') as f: + with open(file_path, "r") as f: json_list = [] for line in f: json_list.append(json.loads(line)) @@ -187,9 +186,9 @@ def analyze_automatic_results(results_path: str, save_path: str) -> None: frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv")) for metric in tqdm.tqdm( - frame_per_metric.keys(), - desc=f"automatic metrics: ", - total=len(frame_per_metric.keys()), + frame_per_metric.keys(), + desc=f"automatic metrics: ", + total=len(frame_per_metric.keys()), ): data = pd.DataFrame(frame_per_metric[metric]) diff --git a/applications/Chat/examples/community/peft/easy_dataset.py b/applications/Chat/examples/community/peft/easy_dataset.py index 2fe293957079..d4b17689e9cb 100644 --- a/applications/Chat/examples/community/peft/easy_dataset.py +++ b/applications/Chat/examples/community/peft/easy_dataset.py @@ -3,7 +3,6 @@ from typing import Dict, Sequence import torch -from datasets import load_dataset from torch.utils.data import Dataset from tqdm import tqdm from transformers import AutoTokenizer @@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i padding="longest", max_length=max_length, truncation=True, - ) for text in strings + ) + for text in strings ] input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] input_ids_lens = labels_lens = [ @@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo class EasySupervisedDataset(Dataset): - def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None: super(EasySupervisedDataset, self).__init__() with open(data_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() - #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:" + # split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:" sources, targets = [], [] for line in all_lines: if "回答:" in line: sep_index = line.index("回答:") - sources.append(line[:sep_index + 3]) - targets.append(line[sep_index + 3:] + tokenizer.eos_token) + sources.append(line[: sep_index + 3]) + targets.append(line[sep_index + 3 :] + tokenizer.eos_token) else: sources.append(line) targets.append("" + tokenizer.eos_token) @@ -83,15 +82,17 @@ def __str__(self): class EasyPromptsDataset(Dataset): - def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None: super(EasyPromptsDataset, self).__init__() with open(data_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() - all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines] + all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines] self.prompts = [ - tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length', - truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0) + tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[ + "input_ids" + ] + .to(torch.cuda.current_device()) + .squeeze(0) for line in tqdm(all_lines) ] self.data_file = data_file @@ -110,7 +111,6 @@ def __str__(self): class EasyRewardDataset(Dataset): - def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None: super(EasyRewardDataset, self).__init__() self.chosen = [] @@ -120,44 +120,42 @@ def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None else: self.end_token = special_token print(self.end_token) - #read all lines in the train_file to a list + # read all lines in the train_file to a list with open(train_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() for line in tqdm(all_lines): data = json.loads(line) - prompt = "提问:" + data['prompt'] + " 回答:" - - chosen = prompt + data['chosen'] + self.end_token - chosen_token = tokenizer(chosen, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.chosen.append({ - "input_ids": chosen_token['input_ids'], - "attention_mask": chosen_token['attention_mask'] - }) - - reject = prompt + data['rejected'] + self.end_token - reject_token = tokenizer(reject, - max_length=max_length, - padding="max_length", - truncation=True, - return_tensors="pt") - self.reject.append({ - "input_ids": reject_token['input_ids'], - "attention_mask": reject_token['attention_mask'] - }) + prompt = "提问:" + data["prompt"] + " 回答:" + + chosen = prompt + data["chosen"] + self.end_token + chosen_token = tokenizer( + chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.chosen.append( + {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]} + ) + + reject = prompt + data["rejected"] + self.end_token + reject_token = tokenizer( + reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + self.reject.append( + {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]} + ) def __len__(self): length = len(self.chosen) return length def __getitem__(self, idx): - return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ - "input_ids"], self.reject[idx]["attention_mask"] - - #python representation of the object and the string representation of the object + return ( + self.chosen[idx]["input_ids"], + self.chosen[idx]["attention_mask"], + self.reject[idx]["input_ids"], + self.reject[idx]["attention_mask"], + ) + + # python representation of the object and the string representation of the object def __repr__(self): return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" @@ -165,26 +163,25 @@ def __str__(self): return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" -''' +""" Easy SFT just accept a text file which can be read line by line. However the datasets will group texts together to max_length so LLM will learn the texts meaning better. If individual lines are not related, just set is_group_texts to False. -''' +""" class EasySFTDataset(Dataset): - def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None: super().__init__() - #read the data_file line by line + # read the data_file line by line with open(data_file, "r", encoding="UTF-8") as f: - #encode the text data line by line and put raw python list input_ids only to raw_input_ids list + # encode the text data line by line and put raw python list input_ids only to raw_input_ids list raw_input_ids = [] for line in f: encoded_ids = tokenizer.encode(line) - #if the encoded_ids is longer than max_length, then split it into several parts + # if the encoded_ids is longer than max_length, then split it into several parts if len(encoded_ids) > max_length: for i in range(0, len(encoded_ids), max_length): - raw_input_ids.append(encoded_ids[i:i + max_length]) + raw_input_ids.append(encoded_ids[i : i + max_length]) else: raw_input_ids.append(encoded_ids) @@ -196,12 +193,13 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_ if is_group_texts: for input_ids in raw_input_ids: if len(current_input_ids) + len(input_ids) > max_length: - #pad the current_input_ids to max_length with tokenizer.pad_token_id + # pad the current_input_ids to max_length with tokenizer.pad_token_id padded_length = max_length - len(current_input_ids) current_input_ids.extend([tokenizer.pad_token_id] * padded_length) grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) attention_mask.append( - torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long) + ) current_input_ids = [] else: current_input_ids.extend(input_ids) @@ -210,14 +208,16 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_ current_input_ids.extend([tokenizer.pad_token_id] * padded_length) grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) attention_mask.append( - torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long) + ) else: - #just append the raw_input_ids to max_length + # just append the raw_input_ids to max_length for input_ids in raw_input_ids: padded_length = max_length - len(input_ids) input_ids.extend([tokenizer.pad_token_id] * padded_length) attention_mask.append( - torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long) + ) grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long)) self.input_ids = grouped_input_ids self.labels = copy.deepcopy(self.input_ids) @@ -227,14 +227,14 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_ def __len__(self): return len(self.input_ids) - #get item from dataset + # get item from dataset def __getitem__(self, idx): return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) - #generate the dataset description to be printed by print in python + # generate the dataset description to be printed by print in python def __repr__(self): return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" - #generate the dataset description to be printed by print in python + # generate the dataset description to be printed by print in python def __str__(self): return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" diff --git a/applications/Chat/examples/community/peft/easy_models.py b/applications/Chat/examples/community/peft/easy_models.py index fe294868159d..db629e50ed94 100644 --- a/applications/Chat/examples/community/peft/easy_models.py +++ b/applications/Chat/examples/community/peft/easy_models.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from coati.models.generation import generate -from coati.models.utils import log_probs_from_logits, masked_mean +from coati.models.utils import log_probs_from_logits from peft import PeftModel from torch.nn.modules import Module from transformers import BloomConfig, BloomForCausalLM @@ -24,38 +24,33 @@ def __init__(self, model: nn.Module) -> None: @torch.no_grad() def generate( - self, - input_ids: torch.Tensor, - return_action_mask: bool = True, - **kwargs + self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: sequences = generate(self.model, input_ids, **kwargs) attention_mask = None - pad_token_id = kwargs.get('pad_token_id', None) + pad_token_id = kwargs.get("pad_token_id", None) if pad_token_id is not None: attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) if not return_action_mask: return sequences, attention_mask, None input_len = input_ids.size(1) - eos_token_id = kwargs.get('eos_token_id', None) + eos_token_id = kwargs.get("eos_token_id", None) if eos_token_id is None: action_mask = torch.ones_like(sequences, dtype=torch.bool) else: # left padding may be applied, only mask action action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 - action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input action_mask[:, :input_len] = False action_mask = action_mask[:, 1:] - return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] + return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :] - def forward(self, - sequences: torch.LongTensor, - num_actions: int, - attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - """Returns action log probs - """ + def forward( + self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Returns action log probs""" output = self.model(sequences, attention_mask=attention_mask) - logits = output['logits'] + logits = output["logits"] log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) return log_probs[:, -num_actions:] @@ -75,11 +70,13 @@ class BLOOMActor(Actor): lora_train_bias (str): LoRA bias training mode. """ - def __init__(self, - pretrained: str = None, - config: Optional[BloomConfig] = None, - checkpoint: bool = False, - lora_path: str = None) -> None: + def __init__( + self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_path: str = None, + ) -> None: if pretrained is not None: model = BloomForCausalLM.from_pretrained(pretrained) elif config is not None: diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py index 9385e457d852..e49db1d2bc1b 100644 --- a/applications/Chat/examples/community/peft/train_peft_prompts.py +++ b/applications/Chat/examples/community/peft/train_peft_prompts.py @@ -1,18 +1,16 @@ import argparse -import pandas as pd import torch import torch.distributed as dist -from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset +from coati.dataset import DataCollatorForSupervisedDataset from coati.models.bloom import BLOOMRM, BLOOMCritic -from coati.models.gpt import GPTRM, GPTActor, GPTCritic -from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM -from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.models.gpt import GPTRM, GPTCritic +from coati.models.llama import LlamaCritic, LlamaRM +from coati.models.opt import OPTRM, OPTCritic from coati.trainer import PPOTrainer from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from easy_dataset import EasyPromptsDataset, EasySupervisedDataset from easy_models import BLOOMActor -from peft import PeftModel from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler @@ -23,24 +21,24 @@ def main(args): # configure strategy - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5) - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5) + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') if args.rm_path is not None: - state_dict = torch.load(args.rm_path, map_location='cpu') + state_dict = torch.load(args.rm_path, map_location="cpu") # configure model - if args.model == 'bloom': + if args.model == "bloom": # initial_model = BLOOMActor(pretrained=args.pretrain) - print('Using peft lora to load Bloom model as initial_model') + print("Using peft lora to load Bloom model as initial_model") initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) - print('Using peft lora to load Bloom model as initial_model (Done)') + print("Using peft lora to load Bloom model as initial_model (Done)") else: raise ValueError(f'Unsupported actor model "{args.model}"') @@ -49,59 +47,59 @@ def main(args): else: rm_model_name = args.rm_model - if rm_model_name == 'gpt2': + if rm_model_name == "gpt2": reward_model = GPTRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'bloom': + elif rm_model_name == "bloom": print("load bloom reward model ", args.rm_pretrain) reward_model = BLOOMRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'opt': + elif rm_model_name == "opt": reward_model = OPTRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'llama': + elif rm_model_name == "llama": reward_model = LlamaRM(pretrained=args.rm_pretrain) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') if args.rm_path is not None: - print('Loading reward model from', args.rm_path) + print("Loading reward model from", args.rm_path) reward_model.load_state_dict(state_dict) - if args.strategy != 'colossalai_gemini': + if args.strategy != "colossalai_gemini": initial_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.to(torch.float16).to(torch.cuda.current_device()) with strategy.model_init_context(): - if args.model == 'bloom': + if args.model == "bloom": # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - print('Using peft lora to load Bloom model as Actor') + print("Using peft lora to load Bloom model as Actor") actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) - print('Using peft lora to load Bloom model as Actor (Done)') + print("Using peft lora to load Bloom model as Actor (Done)") else: raise ValueError(f'Unsupported actor model "{args.model}"') - if rm_model_name == 'gpt2': + if rm_model_name == "gpt2": critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'bloom': + elif rm_model_name == "bloom": print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True) critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) print("load bloom critic (Done) ") - elif rm_model_name == 'opt': + elif rm_model_name == "opt": critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'llama': + elif rm_model_name == "llama": critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') if args.rm_path is not None: - print('Loading reward model from', args.rm_path) + print("Loading reward model from", args.rm_path) critic.load_state_dict(state_dict) del state_dict - if args.strategy != 'colossalai_gemini': + if args.strategy != "colossalai_gemini": critic.to(torch.float16).to(torch.cuda.current_device()) actor.to(torch.float16).to(torch.cuda.current_device()) # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): actor_optim = HybridAdam(actor.parameters(), lr=1e-7) critic_optim = HybridAdam(critic.parameters(), lr=1e-7) else: @@ -109,18 +107,18 @@ def main(args): critic_optim = Adam(critic.parameters(), lr=1e-7) # configure tokenizer - if args.model == 'gpt2': + if args.model == "gpt2": tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': + elif args.model == "opt": tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) - tokenizer.eos_token = '<\s>' + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') @@ -132,26 +130,27 @@ def main(args): prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) else: prompt_sampler = None - prompt_dataloader = DataLoader(prompt_dataset, - shuffle=(prompt_sampler is None), - sampler=prompt_sampler, - batch_size=args.train_batch_size) + prompt_dataloader = DataLoader( + prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size + ) pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer) if dist.is_initialized() and dist.get_world_size() > 1: pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) else: pretrain_sampler = None - pretrain_dataloader = DataLoader(pretrain_dataset, - shuffle=(pretrain_sampler is None), - sampler=pretrain_sampler, - batch_size=args.ptx_batch_size, - collate_fn=data_collator) + pretrain_dataloader = DataLoader( + pretrain_dataset, + shuffle=(pretrain_sampler is None), + sampler=pretrain_sampler, + batch_size=args.ptx_batch_size, + collate_fn=data_collator, + ) def tokenize_fn(texts): # MUST padding to max length to ensure inputs of all ranks have the same length # Different length may lead to hang when using gemini, as different generation steps - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True) return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) @@ -178,45 +177,46 @@ def tokenize_fn(texts): eos_token_id=tokenizer.eos_token_id, ) - trainer.fit(prompt_dataloader=prompt_dataloader, - pretrain_dataloader=pretrain_dataloader, - num_episodes=args.num_episodes, - num_update_steps=args.num_update_steps, - num_collect_steps=args.num_collect_steps) + trainer.fit( + prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + num_episodes=args.num_episodes, + num_update_steps=args.num_update_steps, + num_collect_steps=args.num_collect_steps, + ) # save model checkpoint after fitting trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(actor_optim, - 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset') - parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='ddp', - help='strategy to use') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--sft_lora_path', type=str, default=None) - parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--rm_path', type=str, default=None) - parser.add_argument('--rm_pretrain', type=str, default=None) - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--num_collect_steps', type=int, default=10) - parser.add_argument('--num_update_steps', type=int, default=5) - parser.add_argument('--train_batch_size', type=int, default=2) - parser.add_argument('--ptx_batch_size', type=int, default=1) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--kl_coef', type=float, default=0.1) - parser.add_argument('--ptx_coef', type=float, default=0.9) + parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset") + parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset") + parser.add_argument( + "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use" + ) + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--sft_lora_path", type=str, default=None) + parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--rm_path", type=str, default=None) + parser.add_argument("--rm_pretrain", type=str, default=None) + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--num_collect_steps", type=int, default=10) + parser.add_argument("--num_update_steps", type=int, default=5) + parser.add_argument("--train_batch_size", type=int, default=2) + parser.add_argument("--ptx_batch_size", type=int, default=1) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--kl_coef", type=float, default=0.1) + parser.add_argument("--ptx_coef", type=float, default=0.9) args = parser.parse_args() main(args) diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py index 4af08e6d0141..0b62dd652adb 100644 --- a/applications/Chat/examples/community/peft/train_peft_sft.py +++ b/applications/Chat/examples/community/peft/train_peft_sft.py @@ -1,18 +1,10 @@ import argparse import os -import loralib as lora import torch import torch.distributed as dist -from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset -from coati.models.base import RewardModel -from coati.models.bloom import BLOOMLM -from coati.models.gpt import GPTLM -from coati.models.llama import LlamaLM -from coati.models.opt import OPTLM from coati.trainer import SFTTrainer from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy -from datasets import load_dataset from easy_dataset import EasyDataset from peft import LoraConfig, PeftModel, TaskType, get_peft_model from torch.optim import Adam @@ -29,75 +21,76 @@ def train(args): # configure strategy - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda') - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda") + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model with strategy.model_init_context(): - print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested') + print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested") model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device()) # if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json - if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \ - and os.path.exists(args.save_path + '/adapter_model.bin'): + if ( + os.path.exists(args.save_path) + and os.path.exists(args.save_path + "/adapter_config.json") + and os.path.exists(args.save_path + "/adapter_model.bin") + ): print("loading from saved peft model ", args.save_path) model = PeftModel.from_pretrained(model, args.save_path) else: # we'll use peft lora library to do the lora lora_rank = args.lora_rank if args.lora_rank > 0 else 32 # config lora with rank of lora_rank - lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, - inference_mode=False, - r=lora_rank, - lora_alpha=32, - lora_dropout=0.1) + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1 + ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': + elif args.model == "opt": tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = AutoTokenizer.from_pretrained( args.pretrain, padding_side="right", use_fast=False, ) - tokenizer.eos_token = '<\s>' + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') - if args.model == 'llama' and args.strategy == 'colossalai_gemini': + if args.model == "llama" and args.strategy == "colossalai_gemini": # this is a hack to deal with the resized embedding # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility for name, param in model.named_parameters(): if not isinstance(param, ColoParameter): - sub_module_name = '.'.join(name.split('.')[:-1]) - weight_name = name.split('.')[-1] + sub_module_name = ".".join(name.split(".")[:-1]) + weight_name = name.split(".")[-1] sub_module = model.get_submodule(sub_module_name) setattr(sub_module, weight_name, ColoParameter(param)) # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) else: optim = Adam(model.parameters(), lr=args.lr) logger = get_dist_logger() - logger.set_level('WARNING') + logger.set_level("WARNING") # configure dataset law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) @@ -108,47 +101,57 @@ def train(args): eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) data_collator = default_collate if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler(train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + train_sampler = DistributedSampler( + train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) if eval_dataset is not None: - eval_sampler = DistributedSampler(eval_dataset, - shuffle=False, - seed=42, - drop_last=False, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + eval_sampler = DistributedSampler( + eval_dataset, + shuffle=False, + seed=42, + drop_last=False, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) else: train_sampler = None eval_sampler = None - train_dataloader = DataLoader(train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - collate_fn=data_collator, - pin_memory=True) + train_dataloader = DataLoader( + train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True, + ) if eval_dataset is not None: - eval_dataloader = DataLoader(eval_dataset, - shuffle=(eval_sampler is None), - sampler=eval_sampler, - batch_size=args.batch_size, - collate_fn=data_collator, - pin_memory=True) + eval_dataloader = DataLoader( + eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True, + ) else: eval_dataloader = None - trainer = SFTTrainer(model=model, - strategy=strategy, - optim=optim, - train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - batch_size=args.batch_size, - max_epochs=args.max_epochs, - accumulation_steps=args.accumulation_steps) + trainer = SFTTrainer( + model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + batch_size=args.batch_size, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + ) trainer.fit(logger=logger, log_interval=args.log_interval) @@ -156,29 +159,27 @@ def train(args): trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(trainer.optimizer, - 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='ddp') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--dataset', type=str, default=None) - parser.add_argument('--eval_dataset', type=str, default=None) - parser.add_argument('--save_path', type=str, default='output') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--max_epochs', type=int, default=3) - parser.add_argument('--batch_size', type=int, default=4) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") - parser.add_argument('--lr', type=float, default=5e-6) - parser.add_argument('--accumulation_steps', type=int, default=8) - parser.add_argument('--enable_peft_lora', action='store_true', default=False) - parser.add_argument("--is_short_text", action='store_true', default=False) + parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp") + parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom") + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--eval_dataset", type=str, default=None) + parser.add_argument("--save_path", type=str, default="output") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--max_epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log") + parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--enable_peft_lora", action="store_true", default=False) + parser.add_argument("--is_short_text", action="store_true", default=False) args = parser.parse_args() train(args) diff --git a/applications/Chat/examples/community/ray/ray_job_script.py b/applications/Chat/examples/community/ray/ray_job_script.py index 53f304d379fe..e8a1175a9c32 100644 --- a/applications/Chat/examples/community/ray/ray_job_script.py +++ b/applications/Chat/examples/community/ray/ray_job_script.py @@ -6,16 +6,25 @@ def main(api_server_endpoint="http://127.0.0.1:8265"): client = JobSubmissionClient(api_server_endpoint) client.submit_job( - entrypoint= - "python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv", + entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv", runtime_env={ - "working_dir": - "applications/Chat", + "working_dir": "applications/Chat", "pip": [ - "torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain", - "tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat" - ] - }) + "torch==1.13.1", + "transformers>=4.20.1", + "datasets", + "loralib", + "colossalai>=0.2.4", + "langchain", + "tokenizers", + "fastapi", + "sse_starlette", + "wandb", + "sentencepiece", + "gpustat", + ], + }, + ) if __name__ == "__main__": diff --git a/applications/Chat/examples/community/ray/train_prompts_on_ray.py b/applications/Chat/examples/community/ray/train_prompts_on_ray.py index 1bba9ad66fbc..8abd83a8b249 100644 --- a/applications/Chat/examples/community/ray/train_prompts_on_ray.py +++ b/applications/Chat/examples/community/ray/train_prompts_on_ray.py @@ -26,9 +26,14 @@ class ExperienceCompositionRefs: - - def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef, - base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None: + def __init__( + self, + sequences_attention_mask_action_mask_ref: ray.ObjectRef, + action_log_probs_ref: ray.ObjectRef, + base_action_log_probs_ref: ray.ObjectRef, + value_ref: ray.ObjectRef, + r_ref: ray.ObjectRef, + ) -> None: self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref self.action_log_probs_ref = action_log_probs_ref self.base_action_log_probs_ref = base_action_log_probs_ref @@ -37,14 +42,14 @@ def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, acti class ExperienceMaker: - def __init__(self, kl_coef) -> None: self.kl_coef = kl_coef @torch.no_grad() def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs): sequences, attention_mask, action_mask = ray.get( - experiment_computation_refs.sequences_attention_mask_action_mask_ref) + experiment_computation_refs.sequences_attention_mask_action_mask_ref + ) action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref) base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref) r = ray.get(experiment_computation_refs.r_ref) @@ -58,11 +63,10 @@ def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs class DistributedTorchRayActor: - def __init__(self, world_size, rank, local_rank, master_addr, master_port): - logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', - level=logging.INFO, - datefmt='%Y-%m-%d %H:%M:%S') + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" + ) self._model = None self._world_size = world_size self._rank = rank @@ -82,7 +86,7 @@ def _get_current_node_ip(): @staticmethod def _get_free_port(): with socket.socket() as sock: - sock.bind(('', 0)) + sock.bind(("", 0)) return sock.getsockname()[1] def get_master_addr_port(self): @@ -90,7 +94,6 @@ def get_master_addr_port(self): class BasePPORole(DistributedTorchRayActor): - def add_experience_maker(self, kl_coef: float = 0.1): self._experience_maker = ExperienceMaker(kl_coef) @@ -99,12 +102,12 @@ def make_experience(self, experience_computation_ref: ExperienceCompositionRefs) def _init_strategy(self, strategy: str): # configure strategy - if strategy == 'ddp': + if strategy == "ddp": self._strategy = DDPStrategy() - elif strategy == 'colossalai_gemini': - self._strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) - elif strategy == 'colossalai_zero2': - self._strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') + elif strategy == "colossalai_gemini": + self._strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + elif strategy == "colossalai_zero2": + self._strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{strategy}"') @@ -124,11 +127,9 @@ def _prepare_model_with_strategy(self, has_optimizer: bool): def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str): raise NotImplementedError() - def init_model_from_pretrained(self, - strategy: str, - model_class: Type[LoRAModule], - pretrain: str, - has_optimizer=False): + def init_model_from_pretrained( + self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer=False + ): self._init_strategy(strategy) self._load_model_from_pretrained(model_class, pretrain) self._prepare_model_with_strategy(has_optimizer) @@ -138,7 +139,6 @@ def eval(self): class TrainablePPORole(BasePPORole): - def _load_model_from_pretrained(self, model_class, pretrain): with self._strategy.model_init_context(): self._model = model_class(pretrain).to(torch.cuda.current_device()) @@ -161,38 +161,39 @@ def learn_on_experiences(self, experience_refs): @ray.remote(num_gpus=1) class RayPPOActor(TrainablePPORole): - def set_loss_function(self, eps_clip: float): self._actor_loss_fn = PolicyLoss(eps_clip) def load_tokenizer_from_pretrained(self, model_type: str, pretrained): - if model_type == 'gpt2': + if model_type == "gpt2": self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained) self._model_tokenizer.pad_token = self._model_tokenizer.eos_token - elif model_type == 'bloom': + elif model_type == "bloom": self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained) self._model_tokenizer.pad_token = self._model_tokenizer.eos_token - elif model_type == 'opt': + elif model_type == "opt": self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained) else: raise ValueError(f'Unsupported model "{model_type}"') # Set tokenize function for sequence generation def _text_input_tokenize_fn(texts): - batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True) + batch = self._model_tokenizer(texts, return_tensors="pt", max_length=96, padding=True, truncation=True) return {k: v.cuda() for k, v in batch.items()} self._sample_tokenize_function = _text_input_tokenize_fn def setup_generate_kwargs(self, generate_kwargs: dict): from coati.trainer.ppo import _set_default_generate_kwargs + self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model) - self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id - self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id + self._generate_kwargs["pad_token_id"] = self._model_tokenizer.pad_token_id + self._generate_kwargs["eos_token_id"] = self._model_tokenizer.eos_token_id def load_csv_prompt_file_from_url_to_sampler(self, prompt_url): import pandas as pd - prompts = pd.read_csv(prompt_url)['prompt'] + + prompts = pd.read_csv(prompt_url)["prompt"] self._sampler = self._strategy.setup_sampler(prompts) def _generate(self, input_ids, **generate_kwargs): @@ -214,10 +215,9 @@ def calculate_action_log_probs(self, sequence_attention_action_mask): def _training_step(self, experience): num_actions = experience.action_mask.size(1) action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask) - actor_loss = self._actor_loss_fn(action_log_probs, - experience.action_log_probs, - experience.advantages, - action_mask=experience.action_mask) + actor_loss = self._actor_loss_fn( + action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask + ) self._strategy.backward(actor_loss, self._model, self._optimizer) self._strategy.optimizer_step(self._optimizer) self._optimizer.zero_grad() @@ -229,17 +229,18 @@ def save_checkpoint(self, save_path, should_save_optimizer: bool): self._strategy.save_model(self._model, save_path, only_rank0=True) # save optimizer checkpoint on all ranks if should_save_optimizer: - self._strategy.save_optimizer(self._optimizer, - 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + self._strategy.save_optimizer( + self._optimizer, + "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), + only_rank0=False, + ) def generate_answer(self, prompt, max_length=30, num_return_sequences=5): - encoded_input = self._model_tokenizer(prompt, return_tensors='pt') + encoded_input = self._model_tokenizer(prompt, return_tensors="pt") input_ids = {k: v.cuda() for k, v in encoded_input.items()} - sequence, _ = self._model.generate(**input_ids, - max_length=max_length, - return_action_mask=False, - num_return_sequences=num_return_sequences) + sequence, _ = self._model.generate( + **input_ids, max_length=max_length, return_action_mask=False, num_return_sequences=num_return_sequences + ) token_list = list(sequence.data[0]) output = " ".join([self._model_tokenizer.decode(token) for token in token_list]) return output @@ -247,18 +248,16 @@ def generate_answer(self, prompt, max_length=30, num_return_sequences=5): @ray.remote(num_gpus=1) class RayPPOCritic(TrainablePPORole): - def set_loss_function(self, value_clip: float): self._critic_loss_fn = ValueLoss(value_clip) def _training_step(self, experience): - values = self._model(experience.sequences, - action_mask=experience.action_mask, - attention_mask=experience.attention_mask) - critic_loss = self._critic_loss_fn(values, - experience.values, - experience.reward, - action_mask=experience.action_mask) + values = self._model( + experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask + ) + critic_loss = self._critic_loss_fn( + values, experience.values, experience.reward, action_mask=experience.action_mask + ) self._strategy.backward(critic_loss, self._model, self._optimizer) self._strategy.optimizer_step(self._optimizer) self._optimizer.zero_grad() @@ -272,12 +271,12 @@ def calculate_value(self, sequence_attention_action_mask): @ray.remote(num_gpus=1) class RayPPORewardModel(BasePPORole): - def _load_model_from_pretrained(self, model_class, pretrain): with self._strategy.model_init_context(): critic = model_class(pretrained=pretrain).to(torch.cuda.current_device()) - self._model = RewardModel(deepcopy(critic.model), - deepcopy(critic.value_head)).to(torch.cuda.current_device()) + self._model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to( + torch.cuda.current_device() + ) @torch.no_grad() def calculate_r(self, sequence_attention_action_mask): @@ -287,7 +286,6 @@ def calculate_r(self, sequence_attention_action_mask): @ray.remote(num_gpus=1) class RayPPOInitialModel(BasePPORole): - def _load_model_from_pretrained(self, model_class, pretrain): with self._strategy.model_init_context(): self._model = model_class(pretrain).to(torch.cuda.current_device()) @@ -300,8 +298,8 @@ def calculate_base_action_log_probs(self, sequence_attention_action_mask): class PPORayActorGroup: """ - A group of ray actors - Functions start with 'async' should return list of object refs + A group of ray actors + Functions start with 'async' should return list of object refs """ def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None: @@ -319,8 +317,9 @@ def _initiate_actors(self): pg = placement_group(bundles, strategy="STRICT_SPREAD") ray.get(pg.ready()) if pg: - master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None) + master_actor = self.ray_actor_type.options( + scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg, placement_group_bundle_index=0) + ).remote(world_size, 0, 0, None, None) else: master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None) self._actor_handlers = [master_actor] @@ -331,16 +330,20 @@ def _initiate_actors(self): for rank in range(1, world_size): local_rank = rank % self._num_gpus_per_node if pg: - worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote( - world_size, rank, local_rank, master_addr, master_port) + worker_actor = self.ray_actor_type.options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node + ) + ).remote(world_size, rank, local_rank, master_addr, master_port) else: - worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank, - master_addr, master_port) + worker_actor = self.ray_actor_type.options(num_gpus=1).remote( + world_size, rank, local_rank, master_addr, master_port + ) self._actor_handlers.append(worker_actor) - def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str, - has_optimizer: bool): + def async_init_model_from_pretrained( + self, strategy: str, model_class: Type[LoRAModule], pretrain: str, has_optimizer: bool + ): return [ actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer) for actor in self._actor_handlers @@ -348,7 +351,6 @@ def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRA class TrainableModelRayActorGroup(PPORayActorGroup): - def async_learn_on_experiences(self, experience_refs): num_actors = len(self._actor_handlers) learn_result_refs = [] @@ -359,7 +361,6 @@ def async_learn_on_experiences(self, experience_refs): class PPOActorRayActorGroup(TrainableModelRayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPOActor) @@ -381,7 +382,8 @@ def async_calculate_action_log_probs(self, sequences_attention_mask_action_mask_ action_log_probs_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) action_log_probs_refs.append(action_log_probs_ref) return action_log_probs_refs @@ -393,7 +395,6 @@ def save_checkpoint(self, save_path, should_save_optimizer): class PPOCriticRayActorGroup(TrainableModelRayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic) @@ -402,7 +403,8 @@ def async_calculate_value(self, sequences_attention_mask_action_mask_refs): value_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): value_ref = self._actor_handlers[i % num_actors].calculate_value.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) value_refs.append(value_ref) return value_refs @@ -411,7 +413,6 @@ def set_loss_function(self, value_clip: float = 0.4): class PPOInitialRayActorGroup(PPORayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel) @@ -420,13 +421,13 @@ def async_calculate_base_action_log_probs(self, sequences_attention_mask_action_ base_action_log_probs_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) base_action_log_probs_refs.append(base_action_log_probs_ref) return base_action_log_probs_refs class PPORewardRayActorGroup(PPORayActorGroup): - def __init__(self, num_nodes, num_gpus_per_node) -> None: super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel) @@ -435,20 +436,21 @@ def async_calculate_r(self, sequences_attention_mask_action_mask_refs): r_refs = [] for i in range(len(sequences_attention_mask_action_mask_refs)): r_ref = self._actor_handlers[i % num_actors].calculate_r.remote( - sequences_attention_mask_action_mask_refs[i]) + sequences_attention_mask_action_mask_refs[i] + ) r_refs.append(r_ref) return r_refs def main(args): - logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s', - level=logging.INFO, - datefmt='%Y-%m-%d %H:%M:%S') - if args.model == 'gpt2': + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S" + ) + if args.model == "gpt2": actor_model_class, critic_model_class = GPTActor, GPTCritic - elif args.model == 'bloom': + elif args.model == "bloom": actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic - elif args.model == 'opt': + elif args.model == "opt": actor_model_class, critic_model_class = OPTActor, OPTCritic else: raise ValueError(f'Unsupported model "{args.model}"') @@ -462,13 +464,14 @@ def main(args): logging.info("Actors created") # Prepare model for training - generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50} + generate_kwargs = {"max_length": 128, "do_sample": True, "temperature": 1.0, "top_k": 50} ray.get( - actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) + - critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) + - initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) + - reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) + - actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs)) + actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) + + critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) + + initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) + + reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) + + actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs) + ) logging.info("Models prepared for training") # Prepare models for training @@ -483,8 +486,12 @@ def main(args): # Start training logging.info("Training start") # Set all models to eval and add experience maker - all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \ - initial_group._actor_handlers + reward_group._actor_handlers + all_ray_actors = ( + actor_group._actor_handlers + + critic_group._actor_handlers + + initial_group._actor_handlers + + reward_group._actor_handlers + ) num_ray_actors = len(all_ray_actors) ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors]) ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors]) @@ -497,18 +504,28 @@ def main(args): time += 1 # Experience queueing stage sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence( - experience_batch_size) + experience_batch_size + ) base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs( - sequences_attention_mask_action_mask_refs) + sequences_attention_mask_action_mask_refs + ) values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs) r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs) action_log_probs_refs = actor_group.async_calculate_action_log_probs( - sequences_attention_mask_action_mask_refs) - experience_composition_refs.extend([ - ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i], - base_action_log_probs_refs[i], values_refs[i], r_refs[i]) - for i in range(len(sequences_attention_mask_action_mask_refs)) - ]) + sequences_attention_mask_action_mask_refs + ) + experience_composition_refs.extend( + [ + ExperienceCompositionRefs( + sequences_attention_mask_action_mask_refs[i], + action_log_probs_refs[i], + base_action_log_probs_refs[i], + values_refs[i], + r_refs[i], + ) + for i in range(len(sequences_attention_mask_action_mask_refs)) + ] + ) # Learning stage if time % update_timesteps == 0: experience_refs = [] @@ -519,8 +536,9 @@ def main(args): experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref)) # backward ray.get( - actor_group.async_learn_on_experiences(experience_refs) + - critic_group.async_learn_on_experiences(experience_refs)) + actor_group.async_learn_on_experiences(experience_refs) + + critic_group.async_learn_on_experiences(experience_refs) + ) # clear refs queue experience_composition_refs.clear() logging.info("Training finished") @@ -528,26 +546,24 @@ def main(args): actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_csv_url', type=str) - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='ddp') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) - parser.add_argument('--pretrain', type=str, default='gpt2') - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--max_timesteps', type=int, default=10) - parser.add_argument('--update_timesteps', type=int, default=10) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1) - parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1) - parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1) - parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1) - parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1) + parser.add_argument("--prompt_csv_url", type=str) + parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt"]) + parser.add_argument("--pretrain", type=str, default="gpt2") + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--max_timesteps", type=int, default=10) + parser.add_argument("--update_timesteps", type=int, default=10) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--num_actor_nodes", type=int, help="num of nodes to use to host actor model", default=1) + parser.add_argument("--num_critic_nodes", type=int, help="num of nodes to use to host critic model", default=1) + parser.add_argument("--num_initial_nodes", type=int, help="num of nodes to use to host initial model", default=1) + parser.add_argument("--num_reward_nodes", type=int, help="num of nodes to use to host reward model", default=1) + parser.add_argument("--num_gpus_per_node", type=int, help="num of gpus on a ray node", default=1) args = parser.parse_args() ray.init() main(args) diff --git a/applications/Chat/examples/download_model.py b/applications/Chat/examples/download_model.py index c2b5f9a859a9..ec3482b5f789 100644 --- a/applications/Chat/examples/download_model.py +++ b/applications/Chat/examples/download_model.py @@ -22,7 +22,7 @@ def download(self, dir_path: str): file_path = hf_hub_download(self.repo_id, file, local_dir=dir_path) def download_all(self): - file_path = snapshot_download(self.repo_id) + snapshot_download(self.repo_id) def test_init(model: str, dir_path: str): @@ -31,19 +31,19 @@ def test_init(model: str, dir_path: str): actor = GPTActor(config=config) critic = GPTCritic(config=config) reward_model = GPTRM(config=config) - tokenizer = GPT2Tokenizer.from_pretrained(dir_path) + GPT2Tokenizer.from_pretrained(dir_path) elif model == "bloom": config = BloomConfig.from_pretrained(dir_path) actor = BLOOMActor(config=config) critic = BLOOMCritic(config=config) reward_model = BLOOMRM(config=config) - tokenizer = BloomTokenizerFast.from_pretrained(dir_path) + BloomTokenizerFast.from_pretrained(dir_path) elif model == "opt": config = AutoConfig.from_pretrained(dir_path) actor = OPTActor(config=config) critic = OPTCritic(config=config) reward_model = OPTRM(config=config) - tokenizer = AutoTokenizer.from_pretrained(dir_path) + AutoTokenizer.from_pretrained(dir_path) else: raise NotImplementedError(f"Model {model} not implemented") @@ -59,17 +59,12 @@ def test_init(model: str, dir_path: str): exit(0) repo_list = { - "gpt2": HFRepoFiles( - repo_id="gpt2", - files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"] - ), + "gpt2": HFRepoFiles(repo_id="gpt2", files=["config.json", "tokenizer.json", "vocab.json", "merges.txt"]), "bloom": HFRepoFiles( - repo_id="bigscience/bloom-560m", - files=["config.json", "tokenizer.json", "tokenizer_config.json"] + repo_id="bigscience/bloom-560m", files=["config.json", "tokenizer.json", "tokenizer_config.json"] ), "opt": HFRepoFiles( - repo_id="facebook/opt-350m", - files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"] + repo_id="facebook/opt-350m", files=["config.json", "tokenizer_config.json", "vocab.json", "merges.txt"] ), } diff --git a/applications/Chat/examples/generate_conversation_dataset.py b/applications/Chat/examples/generate_conversation_dataset.py index 8d2fbba955b8..7e03b2d54260 100644 --- a/applications/Chat/examples/generate_conversation_dataset.py +++ b/applications/Chat/examples/generate_conversation_dataset.py @@ -31,9 +31,11 @@ def generate_alpaca(): def generate_sharegpt(): # ShareGPT data requires less processing. conversation_dataset = [] - dataset = load_dataset("anon8231489123/ShareGPT_Vicuna_unfiltered", - data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json", - split="train") + dataset = load_dataset( + "anon8231489123/ShareGPT_Vicuna_unfiltered", + data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json", + split="train", + ) conversations = dataset["conversations"] @@ -43,23 +45,24 @@ def generate_sharegpt(): del conv["markdown"] del conv["text"] - conversation = dict(type="conversation", - language="Multilingual", - dataset="ShareGPT", - conversations=conversations[idx]) + conversation = dict( + type="conversation", language="Multilingual", dataset="ShareGPT", conversations=conversations[idx] + ) conversation_dataset.append(conversation) return conversation_dataset -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--dataset', - type=str, - default="All", - choices=["Alpaca", "ShareGPT", "All"], - help="which dataset to convert, All will combine Alpaca and ShareGPT") - parser.add_argument('--save_path', type=str, default="dataset.json", help="path to save the converted dataset") + parser.add_argument( + "--dataset", + type=str, + default="All", + choices=["Alpaca", "ShareGPT", "All"], + help="which dataset to convert, All will combine Alpaca and ShareGPT", + ) + parser.add_argument("--save_path", type=str, default="dataset.json", help="path to save the converted dataset") args = parser.parse_args() conversation_dataset = [] @@ -75,5 +78,5 @@ def generate_sharegpt(): for idx, sample in enumerate(conversation_dataset): sample["id"] = idx + 1 - with open(args.save_path, mode='w') as f: + with open(args.save_path, mode="w") as f: json.dump(conversation_dataset, f, indent=4, default=str, ensure_ascii=False) diff --git a/applications/Chat/examples/generate_prompt_dataset.py b/applications/Chat/examples/generate_prompt_dataset.py index 2abb31c09f82..4eec6feae505 100644 --- a/applications/Chat/examples/generate_prompt_dataset.py +++ b/applications/Chat/examples/generate_prompt_dataset.py @@ -6,7 +6,7 @@ def sample(args): - with open(args.dataset_path, mode='r') as f: + with open(args.dataset_path, mode="r") as f: dataset_list = json.load(f) sampled_dataset = [ @@ -14,18 +14,14 @@ def sample(args): for idx, sample in enumerate(random.sample(dataset_list, args.sample_size)) ] - with open(args.save_path, mode='w') as f: - json.dump(sampled_dataset, f, indent=4, - default=str, ensure_ascii=False) + with open(args.save_path, mode="w") as f: + json.dump(sampled_dataset, f, indent=4, default=str, ensure_ascii=False) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--dataset_path', type=str, default=None, - required=True, help="path to the pretrain dataset") - parser.add_argument('--save_path', type=str, default='prompt.json', - help="path to save the prompt dataset") - parser.add_argument('--sample_size', type=int, - default=16384, help="size of the prompt dataset") + parser.add_argument("--dataset_path", type=str, default=None, required=True, help="path to the pretrain dataset") + parser.add_argument("--save_path", type=str, default="prompt.json", help="path to save the prompt dataset") + parser.add_argument("--sample_size", type=int, default=16384, help="size of the prompt dataset") args = parser.parse_args() sample(args) diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py index e1e57e3cd376..087c49564e43 100644 --- a/applications/Chat/examples/inference.py +++ b/applications/Chat/examples/inference.py @@ -11,13 +11,13 @@ def eval(args): # configure model - if args.model == 'gpt2': + if args.model == "gpt2": actor = GPTActor(pretrained=args.pretrain) - elif args.model == 'bloom': + elif args.model == "bloom": actor = BLOOMActor(pretrained=args.pretrain) - elif args.model == 'opt': + elif args.model == "opt": actor = OPTActor(pretrained=args.pretrain) - elif args.model == 'llama': + elif args.model == "llama": actor = LlamaActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported model "{args.model}"') @@ -28,45 +28,38 @@ def eval(args): actor.load_state_dict(state_dict) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': - tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m') + elif args.model == "bloom": + tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': + elif args.model == "opt": tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - tokenizer.eos_token = '<\s>' + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') actor.eval() - input_ids = tokenizer.encode(args.input, - return_tensors='pt')\ - .to(torch.cuda.current_device()) - outputs = generate(actor, - input_ids, - max_length=args.max_length, - do_sample=True, - top_k=50, - top_p=0.95, - num_return_sequences=1) - output = tokenizer.batch_decode(outputs[0], - skip_special_tokens=True) + input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device()) + outputs = generate( + actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1 + ) + output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True) print(f"[Output]: {''.join(output)}") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--model_path', type=str, default=None) - parser.add_argument('--input', type=str, default='Question: How are you ? Answer:') - parser.add_argument('--max_length', type=int, default=100) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--model_path", type=str, default=None) + parser.add_argument("--input", type=str, default="Question: How are you ? Answer:") + parser.add_argument("--max_length", type=int, default=100) args = parser.parse_args() eval(args) diff --git a/applications/Chat/examples/ray/1mmt_prompt.py b/applications/Chat/examples/ray/1mmt_prompt.py index 5dd52f1790e6..8de6219ec4e9 100644 --- a/applications/Chat/examples/ray/1mmt_prompt.py +++ b/applications/Chat/examples/ray/1mmt_prompt.py @@ -5,7 +5,6 @@ import pandas as pd import ray -import torch from coati.quant import llama_load_quant, low_resource_init from coati.ray.detached_trainer_ppo import DetachedPPOTrainer from coati.ray.experience_maker_holder import ExperienceMakerHolder @@ -23,13 +22,13 @@ def get_free_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] def get_local_ip(): with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) + s.connect(("8.8.8.8", 80)) return s.getsockname()[0] @@ -37,22 +36,25 @@ def main(args): master_addr = str(get_local_ip()) # trainer_env_info trainer_port = str(get_free_port()) - env_info_trainers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_trainers), - 'master_port': trainer_port, - 'master_addr': master_addr - } for rank in range(args.num_trainers)] + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] # maker_env_info maker_port = str(get_free_port()) env_info_maker = { - 'local_rank': '0', - 'rank': '0', - 'world_size': '1', - 'master_port': maker_port, - 'master_addr': master_addr + "local_rank": "0", + "rank": "0", + "world_size": "1", + "master_port": maker_port, + "master_addr": master_addr, } # configure tokenizer @@ -75,27 +77,33 @@ def trainer_model_fn(): eval_performance=True, debug=args.debug, update_lora_weights=not (args.lora_rank == 0), - ) for i, env_info_trainer in enumerate(env_info_trainers) + ) + for i, env_info_trainer in enumerate(env_info_trainers) ] def model_fn(): actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() - if args.initial_model_quant_ckpt is not None and args.model == 'llama': + if args.initial_model_quant_ckpt is not None and args.model == "llama": # quantize initial model actor_cfg = AutoConfig.from_pretrained(args.pretrain) with low_resource_init(), no_init_weights(): initial_model = get_actor_from_args(args.model, config=actor_cfg) - initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, - args.quant_group_size).cuda().requires_grad_(False) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) else: initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() return actor, critic, reward_model, initial_model # configure Experience Maker experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote( - detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)], + detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)], strategy_fn=partial(get_strategy_from_args, args.maker_strategy), model_fn=model_fn, env_info=env_info_maker, @@ -130,12 +138,11 @@ def model_fn(): dataset_size = args.experience_batch_size * 4 def build_dataloader(): - def tokenize_fn(texts): - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True) return {k: v.cuda() for k, v in batch.items()} - dataset = pd.read_csv(args.prompt_path)['prompt'] + dataset = pd.read_csv(args.prompt_path)["prompt"] dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn) return dataloader @@ -144,32 +151,31 @@ def tokenize_fn(texts): ray.get(wait_tasks) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_path', type=str, default=None) - parser.add_argument('--num_trainers', type=int, default=1) - parser.add_argument('--trainer_strategy', - choices=[ - 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', - 'colossalai_zero2_cpu' - ], - default='ddp') - parser.add_argument('--maker_strategy', choices=['naive'], default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--critic_pretrain', type=str, default=None) - parser.add_argument('--experience_steps', type=int, default=4) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--train_epochs', type=int, default=1) - parser.add_argument('--update_steps', type=int, default=2) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) - parser.add_argument('--quant_bits', type=int, default=4) - parser.add_argument('--quant_group_size', type=int, default=128) - parser.add_argument('--debug', action='store_true') + parser.add_argument("--prompt_path", type=str, default=None) + parser.add_argument("--num_trainers", type=int, default=1) + parser.add_argument( + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") args = parser.parse_args() ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) main(args) diff --git a/applications/Chat/examples/ray/mmmt_prompt.py b/applications/Chat/examples/ray/mmmt_prompt.py index 76929c9d0144..7c03a0468b02 100644 --- a/applications/Chat/examples/ray/mmmt_prompt.py +++ b/applications/Chat/examples/ray/mmmt_prompt.py @@ -5,7 +5,6 @@ import pandas as pd import ray -import torch from coati.quant import llama_load_quant, low_resource_init from coati.ray.detached_trainer_ppo import DetachedPPOTrainer from coati.ray.experience_maker_holder import ExperienceMakerHolder @@ -23,13 +22,13 @@ def get_free_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] def get_local_ip(): with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(('8.8.8.8', 80)) + s.connect(("8.8.8.8", 80)) return s.getsockname()[0] @@ -37,23 +36,29 @@ def main(args): master_addr = str(get_local_ip()) # trainer_env_info trainer_port = str(get_free_port()) - env_info_trainers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_trainers), - 'master_port': trainer_port, - 'master_addr': master_addr - } for rank in range(args.num_trainers)] + env_info_trainers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_trainers), + "master_port": trainer_port, + "master_addr": master_addr, + } + for rank in range(args.num_trainers) + ] # maker_env_info maker_port = str(get_free_port()) - env_info_makers = [{ - 'local_rank': '0', - 'rank': str(rank), - 'world_size': str(args.num_makers), - 'master_port': maker_port, - 'master_addr': master_addr - } for rank in range(args.num_makers)] + env_info_makers = [ + { + "local_rank": "0", + "rank": str(rank), + "world_size": str(args.num_makers), + "master_port": maker_port, + "master_addr": master_addr, + } + for rank in range(args.num_makers) + ] # configure tokenizer tokenizer = AutoTokenizer.from_pretrained(args.pretrain) @@ -63,13 +68,18 @@ def model_fn(): actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda() - if args.initial_model_quant_ckpt is not None and args.model == 'llama': + if args.initial_model_quant_ckpt is not None and args.model == "llama": # quantize initial model actor_cfg = AutoConfig.from_pretrained(args.pretrain) with low_resource_init(), no_init_weights(): initial_model = get_actor_from_args(args.model, config=actor_cfg) - initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, - args.quant_group_size).cuda().requires_grad_(False) + initial_model.model = ( + llama_load_quant( + initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size + ) + .cuda() + .requires_grad_(False) + ) else: initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda() return actor, critic, reward_model, initial_model @@ -78,7 +88,7 @@ def model_fn(): experience_holder_refs = [ ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote( detached_trainer_name_list=[ - f'trainer{x}' + f"trainer{x}" for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False) ], strategy_fn=partial(get_strategy_from_args, args.maker_strategy), @@ -87,8 +97,8 @@ def model_fn(): kl_coef=0.1, debug=args.debug, update_lora_weights=not (args.lora_rank == 0), - # sync_models_from_trainers=True, - # generation kwargs: + # sync_models_from_trainers=True, + # generation kwargs: max_length=512, do_sample=True, temperature=1.0, @@ -128,12 +138,11 @@ def trainer_model_fn(): dataset_size = args.experience_batch_size * 4 def build_dataloader(): - def tokenize_fn(texts): - batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True) return {k: v.cuda() for k, v in batch.items()} - dataset = pd.read_csv(args.prompt_path)['prompt'] + dataset = pd.read_csv(args.prompt_path)["prompt"] dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn) return dataloader @@ -148,39 +157,44 @@ def tokenize_fn(texts): for experience_holder_ref in experience_holder_refs: wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps)) - total_steps = args.experience_batch_size * args.experience_steps * \ - args.num_makers // (args.num_trainers * args.train_batch_size) + total_steps = ( + args.experience_batch_size + * args.experience_steps + * args.num_makers + // (args.num_trainers * args.train_batch_size) + ) for trainer_ref in trainer_refs: wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) ray.get(wait_tasks) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_path', type=str, default=None) - parser.add_argument('--num_makers', type=int, default=1) - parser.add_argument('--num_trainers', type=int, default=1) + parser.add_argument("--prompt_path", type=str, default=None) + parser.add_argument("--num_makers", type=int, default=1) + parser.add_argument("--num_trainers", type=int, default=1) parser.add_argument( - '--trainer_strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', 'colossalai_zero2_cpu'], - default='ddp') - parser.add_argument('--maker_strategy', choices=['naive'], default='naive') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--critic_pretrain', type=str, default=None) - parser.add_argument('--experience_steps', type=int, default=4) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--train_epochs', type=int, default=1) - parser.add_argument('--update_steps', type=int, default=2) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - - parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) - parser.add_argument('--quant_bits', type=int, default=4) - parser.add_argument('--quant_group_size', type=int, default=128) - parser.add_argument('--debug', action='store_true') + "--trainer_strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"], + default="ddp", + ) + parser.add_argument("--maker_strategy", choices=["naive"], default="naive") + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--critic_pretrain", type=str, default=None) + parser.add_argument("--experience_steps", type=int, default=4) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--train_epochs", type=int, default=1) + parser.add_argument("--update_steps", type=int, default=2) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + + parser.add_argument("--initial_model_quant_ckpt", type=str, default=None) + parser.add_argument("--quant_bits", type=int, default=4) + parser.add_argument("--quant_group_size", type=int, default=128) + parser.add_argument("--debug", action="store_true") args = parser.parse_args() ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt index 5d0f9f927d17..d3ea7b0c8142 100644 --- a/applications/Chat/examples/requirements.txt +++ b/applications/Chat/examples/requirements.txt @@ -1,3 +1,3 @@ pandas>=1.4.1 sentencepiece -colossalai==0.3.1 \ No newline at end of file +colossalai==0.3.1 diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index d27a70a3fef6..ad688b07a7f2 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -20,28 +20,28 @@ def main(args): # configure strategy - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5) + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') if args.rm_path is not None: - warnings.warn('LoRA weights should be merged with the model weights') - state_dict = torch.load(args.rm_path, map_location='cpu') + warnings.warn("LoRA weights should be merged with the model weights") + state_dict = torch.load(args.rm_path, map_location="cpu") with strategy.model_init_context(): # configure model - if args.model == 'gpt2': + if args.model == "gpt2": initial_model = GPTActor(pretrained=args.pretrain) - elif args.model == 'bloom': + elif args.model == "bloom": initial_model = BLOOMActor(pretrained=args.pretrain) - elif args.model == 'opt': + elif args.model == "opt": initial_model = OPTActor(pretrained=args.pretrain) - elif args.model == 'llama': + elif args.model == "llama": initial_model = LlamaActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported actor model "{args.model}"') @@ -51,13 +51,13 @@ def main(args): else: rm_model_name = args.rm_model - if rm_model_name == 'gpt2': + if rm_model_name == "gpt2": reward_model = GPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == 'bloom': + elif rm_model_name == "bloom": reward_model = BLOOMRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == 'opt': + elif rm_model_name == "opt": reward_model = OPTRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) - elif rm_model_name == 'llama': + elif rm_model_name == "llama": reward_model = LlamaRM(pretrained=args.rm_pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') @@ -68,24 +68,24 @@ def main(args): initial_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.to(torch.float16).to(torch.cuda.current_device()) - if args.model == 'gpt2': + if args.model == "gpt2": actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'bloom': + elif args.model == "bloom": actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'opt': + elif args.model == "opt": actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'llama': + elif args.model == "llama": actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported actor model "{args.model}"') - if rm_model_name == 'gpt2': + if rm_model_name == "gpt2": critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'bloom': + elif rm_model_name == "bloom": critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'opt': + elif rm_model_name == "opt": critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) - elif rm_model_name == 'llama': + elif rm_model_name == "llama": critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') @@ -94,12 +94,12 @@ def main(args): critic.load_state_dict(state_dict, strict=False) del state_dict - if args.strategy != 'colossalai_gemini': + if args.strategy != "colossalai_gemini": critic.to(torch.float16).to(torch.cuda.current_device()) actor.to(torch.float16).to(torch.cuda.current_device()) # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): actor_optim = HybridAdam(actor.parameters(), lr=1e-7) critic_optim = HybridAdam(critic.parameters(), lr=1e-7) else: @@ -107,22 +107,22 @@ def main(args): critic_optim = Adam(critic.parameters(), lr=1e-7) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained( - 'gpt2' if args.tokenizer is None else args.tokenizer) + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained( - 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) + "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer + ) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained( - "facebook/opt-350m" if args.tokenizer is None else args.tokenizer) + elif args.model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = LlamaTokenizer.from_pretrained( - "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) - tokenizer.eos_token = '<\s>' + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer + ) + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') @@ -132,27 +132,25 @@ def main(args): prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) else: prompt_sampler = None - prompt_dataloader = DataLoader(prompt_dataset, - shuffle=(prompt_sampler is None), - sampler=prompt_sampler, - batch_size=args.experience_batch_size) - - pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, - data_path=args.pretrain_dataset, - max_datasets_size=16384, - max_length=args.max_input_len) + prompt_dataloader = DataLoader( + prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.experience_batch_size + ) + + pretrain_dataset = SupervisedDataset( + tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384, max_length=args.max_input_len + ) if dist.is_initialized() and dist.get_world_size() > 1: pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) else: pretrain_sampler = None - pretrain_dataloader = DataLoader(pretrain_dataset, - shuffle=(pretrain_sampler is None), - sampler=pretrain_sampler, - batch_size=args.ptx_batch_size) + pretrain_dataloader = DataLoader( + pretrain_dataset, shuffle=(pretrain_sampler is None), sampler=pretrain_sampler, batch_size=args.ptx_batch_size + ) # NOTE: For small models like opt-1.3b, reward model and initial model are not required to be parallelized. - (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = \ - strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model) + (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare( + (actor, actor_optim), (critic, critic_optim), reward_model, initial_model + ) # configure trainer trainer = PPOTrainer( @@ -173,50 +171,54 @@ def main(args): top_k=50, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, - offload_inference_models=args.strategy != 'colossalai_gemini' + offload_inference_models=args.strategy != "colossalai_gemini", ) - trainer.fit(prompt_dataloader=prompt_dataloader, - pretrain_dataloader=pretrain_dataloader, - num_episodes=args.num_episodes, - num_collect_steps=args.num_collect_steps, - num_update_steps=args.num_update_steps) + trainer.fit( + prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + num_episodes=args.num_episodes, + num_collect_steps=args.num_collect_steps, + num_update_steps=args.num_update_steps, + ) # save model checkpoint after fitting strategy.save_model(actor, args.save_path, only_rank0=True) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(actor_optim, - 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--prompt_dataset', type=str, default=None, help='path to the prompt dataset') - parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='colossalai_zero2', - help='strategy to use') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--tokenizer', type=str, default=None) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) - parser.add_argument('--rm_path', type=str, default=None) - parser.add_argument('--rm_pretrain', type=str, default=None) - parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--num_episodes', type=int, default=10) - parser.add_argument('--num_collect_steps', type=int, default=10) - parser.add_argument('--num_update_steps', type=int, default=5) - parser.add_argument('--train_batch_size', type=int, default=8) - parser.add_argument('--ptx_batch_size', type=int, default=1) - parser.add_argument('--experience_batch_size', type=int, default=8) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--kl_coef', type=float, default=0.1) - parser.add_argument('--ptx_coef', type=float, default=0.9) - parser.add_argument('--max_input_len', type=int, default=96) - parser.add_argument('--max_seq_len', type=int, default=128) + parser.add_argument("--prompt_dataset", type=str, default=None, help="path to the prompt dataset") + parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset") + parser.add_argument( + "--strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2"], + default="colossalai_zero2", + help="strategy to use", + ) + parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"]) + parser.add_argument("--rm_path", type=str, default=None) + parser.add_argument("--rm_pretrain", type=str, default=None) + parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--num_collect_steps", type=int, default=10) + parser.add_argument("--num_update_steps", type=int, default=5) + parser.add_argument("--train_batch_size", type=int, default=8) + parser.add_argument("--ptx_batch_size", type=int, default=1) + parser.add_argument("--experience_batch_size", type=int, default=8) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--kl_coef", type=float, default=0.1) + parser.add_argument("--ptx_coef", type=float, default=0.9) + parser.add_argument("--max_input_len", type=int, default=96) + parser.add_argument("--max_seq_len", type=int, default=128) args = parser.parse_args() main(args) diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py index 190460bc20f6..a07f4b5ca812 100644 --- a/applications/Chat/examples/train_reward_model.py +++ b/applications/Chat/examples/train_reward_model.py @@ -24,24 +24,24 @@ def train(args): # configure strategy - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda') - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda") + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') # configure model with strategy.model_init_context(): - if args.model == 'bloom': + if args.model == "bloom": model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'opt': + elif args.model == "opt": model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'gpt2': + elif args.model == "gpt2": model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank) - elif args.model == 'llama': + elif args.model == "llama": model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported model "{args.model}"') @@ -53,36 +53,36 @@ def train(args): model.load_state_dict(state_dict) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained( - 'gpt2' if args.tokenizer is None else args.tokenizer) + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained( - 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) + "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer + ) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained( - "facebook/opt-350m" if args.tokenizer is None else args.tokenizer) + elif args.model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = LlamaTokenizer.from_pretrained( - "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) - tokenizer.eos_token = '<\s>' + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer + ) + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token else: raise ValueError(f'Unsupported model "{args.model}"') # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): optim = HybridAdam(model.parameters(), lr=5e-6) else: optim = Adam(model.parameters(), lr=5e-6) # configure loss function - if args.loss_fn == 'log_sig': + if args.loss_fn == "log_sig": loss_fn = LogSigLoss() - elif args.loss_fn == 'log_exp': + elif args.loss_fn == "log_exp": loss_fn = LogExpLoss() else: raise ValueError(f'Unsupported loss function "{args.loss_fn}"') @@ -94,18 +94,18 @@ def train(args): data = load_dataset(args.dataset) if args.test: - train_data = data['train'].select(range(20)) - eval_data = data['test'].select(range(5)) + train_data = data["train"].select(range(20)) + eval_data = data["test"].select(range(5)) else: - train_data = data['train'] - eval_data = data['test'] - valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5))) + train_data = data["train"] + eval_data = data["test"] + valid_data = data["test"].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5))) - if args.dataset == 'Dahoas/rm-static': + if args.dataset == "Dahoas/rm-static": train_dataset = RmStaticDataset(train_data, tokenizer, args.max_len) valid_dataset = RmStaticDataset(valid_data, tokenizer, args.max_len) eval_dataset = RmStaticDataset(eval_data, tokenizer, args.max_len) - elif args.dataset == 'Anthropic/hh-rlhf': + elif args.dataset == "Anthropic/hh-rlhf": train_dataset = HhRlhfDataset(train_data, tokenizer, args.max_len) valid_dataset = HhRlhfDataset(valid_data, tokenizer, args.max_len) eval_dataset = HhRlhfDataset(eval_data, tokenizer, args.max_len) @@ -113,90 +113,99 @@ def train(args): raise ValueError(f'Unsupported dataset "{args.dataset}"') if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler(train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) - valid_sampler = DistributedSampler(valid_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) - eval_sampler = DistributedSampler(eval_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + train_sampler = DistributedSampler( + train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) + valid_sampler = DistributedSampler( + valid_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) + eval_sampler = DistributedSampler( + eval_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) else: train_sampler = None valid_sampler = None eval_sampler = None - train_dataloader = DataLoader(train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - pin_memory=True) - - valid_dataloader = DataLoader(valid_dataset, - shuffle=(valid_sampler is None), - sampler=valid_sampler, - batch_size=args.batch_size, - pin_memory=True) - - eval_dataloader = DataLoader(eval_dataset, - shuffle=(eval_sampler is None), - sampler=eval_sampler, - batch_size=args.batch_size, - pin_memory=True) + train_dataloader = DataLoader( + train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) + + valid_dataloader = DataLoader( + valid_dataset, + shuffle=(valid_sampler is None), + sampler=valid_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) + + eval_dataloader = DataLoader( + eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, pin_memory=True + ) lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100) strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)) - model = strategy_dict['model'] - optim = strategy_dict['optimizer'] - lr_scheduler = strategy_dict['lr_scheduler'] - trainer = RewardModelTrainer(model=model, - strategy=strategy, - optim=optim, - lr_scheduler=lr_scheduler, - loss_fn=loss_fn, - max_epochs=args.max_epochs) + model = strategy_dict["model"] + optim = strategy_dict["optimizer"] + lr_scheduler = strategy_dict["lr_scheduler"] + trainer = RewardModelTrainer( + model=model, + strategy=strategy, + optim=optim, + lr_scheduler=lr_scheduler, + loss_fn=loss_fn, + max_epochs=args.max_epochs, + ) trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader) # save model checkpoint after fitting on only rank0 strategy.save_model(model, args.save_path, only_rank0=True) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(trainer.optimizer, - 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'], - default='colossalai_zero2') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') - parser.add_argument('--tokenizer', type=str, default=None) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--model_path', type=str, default=None) - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--dataset', - type=str, - choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'], - default='Dahoas/rm-static') - parser.add_argument('--subset', type=lambda x: None if x == 'None' else x, default=None) - parser.add_argument('--save_path', type=str, default='rm_ckpt') - parser.add_argument('--max_epochs', type=int, default=1) - parser.add_argument('--batch_size', type=int, default=1) - parser.add_argument('--max_len', type=int, default=512) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp']) - parser.add_argument('--test', type=bool, default=False) + parser.add_argument( + "--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="colossalai_zero2" + ) + parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--model_path", type=str, default=None) + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument( + "--dataset", type=str, choices=["Anthropic/hh-rlhf", "Dahoas/rm-static"], default="Dahoas/rm-static" + ) + parser.add_argument("--subset", type=lambda x: None if x == "None" else x, default=None) + parser.add_argument("--save_path", type=str, default="rm_ckpt") + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--max_len", type=int, default=512) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"]) + parser.add_argument("--test", type=bool, default=False) args = parser.parse_args() train(args) diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index f068ea2bf5de..1729abb86a09 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -6,18 +6,18 @@ import torch.distributed as dist from coati.dataset import SFTDataset, SupervisedDataset from coati.models.bloom import BLOOMActor +from coati.models.chatglm import ChatGLMActor +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from coati.models.gpt import GPTActor from coati.models.llama import LlamaActor from coati.models.opt import OPTActor -from coati.models.chatglm import ChatGLMActor from coati.trainer import SFTTrainer from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy from datasets import load_dataset from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer +from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from transformers.trainer import get_scheduler @@ -28,14 +28,14 @@ def train(args): # configure strategy - if args.strategy == 'ddp': + if args.strategy == "ddp": strategy = DDPStrategy() - elif args.strategy == 'colossalai_gemini': - strategy = GeminiStrategy(placement_policy='cuda') - elif args.strategy == 'colossalai_zero2': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') - elif args.strategy == 'colossalai_zero2_cpu': - strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu') + elif args.strategy == "colossalai_gemini": + strategy = GeminiStrategy(placement_policy="cuda") + elif args.strategy == "colossalai_zero2": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda") + elif args.strategy == "colossalai_zero2_cpu": + strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu") else: raise ValueError(f'Unsupported strategy "{args.strategy}"') @@ -44,23 +44,15 @@ def train(args): warnings.warn("Gradient checkpoint is disabled when using LoRA") args.grad_checkpoint = False with strategy.model_init_context(): - if args.model == 'bloom': - model = BLOOMActor(pretrained=args.pretrain, - lora_rank=args.lora_rank, - checkpoint=args.grad_checkpoint) - elif args.model == 'opt': - model = OPTActor(pretrained=args.pretrain, - lora_rank=args.lora_rank, - checkpoint=args.grad_checkpoint) - elif args.model == 'gpt2': - model = GPTActor(pretrained=args.pretrain, - lora_rank=args.lora_rank, - checkpoint=args.grad_checkpoint) - elif args.model == 'llama': - model = LlamaActor(pretrained=args.pretrain, - lora_rank=args.lora_rank, - checkpoint=args.grad_checkpoint) - elif args.model == 'chatglm': + if args.model == "bloom": + model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "opt": + model = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "gpt2": + model = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "llama": + model = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint) + elif args.model == "chatglm": model = ChatGLMActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported model "{args.model}"') @@ -68,144 +60,157 @@ def train(args): model.to(torch.float16).to(torch.cuda.current_device()) # configure tokenizer - if args.model == 'gpt2': - tokenizer = GPT2Tokenizer.from_pretrained( - 'gpt2' if args.tokenizer is None else args.tokenizer) + if args.model == "gpt2": + tokenizer = GPT2Tokenizer.from_pretrained("gpt2" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'bloom': + elif args.model == "bloom": tokenizer = BloomTokenizerFast.from_pretrained( - 'bigscience/bloom-560m' if args.tokenizer is None else args.tokenizer) + "bigscience/bloom-560m" if args.tokenizer is None else args.tokenizer + ) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained( - "facebook/opt-350m" if args.tokenizer is None else args.tokenizer) + elif args.model == "opt": + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m" if args.tokenizer is None else args.tokenizer) tokenizer.pad_token = tokenizer.eos_token - elif args.model == 'llama': + elif args.model == "llama": tokenizer = LlamaTokenizer.from_pretrained( - "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer) - tokenizer.eos_token = '<\s>' + "hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer + ) + tokenizer.eos_token = "<\s>" tokenizer.pad_token = tokenizer.unk_token - elif args.model == 'chatglm': + elif args.model == "chatglm": tokenizer = ChatGLMTokenizer.from_pretrained( - "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True) + "THUDM/chatglm-6b" if args.tokenizer is None else args.tokenizer, trust_remote_code=True + ) else: raise ValueError(f'Unsupported model "{args.model}"') - if args.model == 'llama' and args.strategy == 'colossalai_gemini': + if args.model == "llama" and args.strategy == "colossalai_gemini": # this is a hack to deal with the resized embedding # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility for name, param in model.named_parameters(): if not isinstance(param, ColoParameter): - sub_module_name = '.'.join(name.split('.')[:-1]) - weight_name = name.split('.')[-1] + sub_module_name = ".".join(name.split(".")[:-1]) + weight_name = name.split(".")[-1] sub_module = model.get_submodule(sub_module_name) setattr(sub_module, weight_name, ColoParameter(param)) # configure optimizer - if args.strategy.startswith('colossalai'): + if args.strategy.startswith("colossalai"): optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) else: optim = Adam(model.parameters(), lr=args.lr) logger = get_dist_logger() # configure dataset - if args.dataset == 'yizhongw/self_instruct': - train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train') - eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test') + if args.dataset == "yizhongw/self_instruct": + train_data = load_dataset(args.dataset, "super_natural_instructions", split="train") + eval_data = load_dataset(args.dataset, "super_natural_instructions", split="test") train_dataset = SFTDataset(train_data, tokenizer, args.max_len) eval_dataset = SFTDataset(eval_data, tokenizer, args.max_len) else: - train_dataset = SupervisedDataset(tokenizer=tokenizer, - data_path=args.dataset, - max_datasets_size=args.max_datasets_size, - max_length=args.max_len) + train_dataset = SupervisedDataset( + tokenizer=tokenizer, + data_path=args.dataset, + max_datasets_size=args.max_datasets_size, + max_length=args.max_len, + ) eval_dataset = None if dist.is_initialized() and dist.get_world_size() > 1: - train_sampler = DistributedSampler(train_dataset, - shuffle=True, - seed=42, - drop_last=True, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + train_sampler = DistributedSampler( + train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) if eval_dataset is not None: - eval_sampler = DistributedSampler(eval_dataset, - shuffle=False, - seed=42, - drop_last=False, - rank=dist.get_rank(), - num_replicas=dist.get_world_size()) + eval_sampler = DistributedSampler( + eval_dataset, + shuffle=False, + seed=42, + drop_last=False, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) else: train_sampler = None eval_sampler = None - train_dataloader = DataLoader(train_dataset, - shuffle=(train_sampler is None), - sampler=train_sampler, - batch_size=args.batch_size, - pin_memory=True) + train_dataloader = DataLoader( + train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) if eval_dataset is not None: - eval_dataloader = DataLoader(eval_dataset, - shuffle=(eval_sampler is None), - sampler=eval_sampler, - batch_size=args.batch_size, - pin_memory=True) + eval_dataloader = DataLoader( + eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + pin_memory=True, + ) else: eval_dataloader = None num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) - lr_scheduler = get_scheduler("cosine", - optim, - num_warmup_steps=math.ceil(max_steps * 0.03), - num_training_steps=max_steps) + lr_scheduler = get_scheduler( + "cosine", optim, num_warmup_steps=math.ceil(max_steps * 0.03), num_training_steps=max_steps + ) strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)) - model = strategy_dict['model'] - optim = strategy_dict['optimizer'] - lr_scheduler = strategy_dict['lr_scheduler'] - trainer = SFTTrainer(model=model, - strategy=strategy, - optim=optim, - lr_scheduler=lr_scheduler, - max_epochs=args.max_epochs, - accumulation_steps=args.accumulation_steps) - - trainer.fit(train_dataloader=train_dataloader, - eval_dataloader=eval_dataloader, - logger=logger, - use_wandb=args.use_wandb) + model = strategy_dict["model"] + optim = strategy_dict["optimizer"] + lr_scheduler = strategy_dict["lr_scheduler"] + trainer = SFTTrainer( + model=model, + strategy=strategy, + optim=optim, + lr_scheduler=lr_scheduler, + max_epochs=args.max_epochs, + accumulation_steps=args.accumulation_steps, + ) + + trainer.fit( + train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, logger=logger, use_wandb=args.use_wandb + ) # save model checkpoint after fitting on only rank0 strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(trainer.optimizer, - 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + strategy.save_optimizer( + trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--strategy', - choices=['ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'], - default='colossalai_zero2') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama', 'chatglm'], default='bloom') - parser.add_argument('--tokenizer', type=str, default=None) - parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--dataset', type=str, default=None) - parser.add_argument('--max_datasets_size', type=int, default=None) - parser.add_argument('--save_path', type=str, default='output') - parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--max_epochs', type=int, default=3) - parser.add_argument('--batch_size', type=int, default=4) - parser.add_argument('--max_len', type=int, default=512) - parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") - parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") - parser.add_argument('--lr', type=float, default=5e-6) - parser.add_argument('--accumulation_steps', type=int, default=8) - parser.add_argument('--use_wandb', default=False, action='store_true') - parser.add_argument('--grad_checkpoint', default=False, action='store_true') + parser.add_argument( + "--strategy", + choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_zero2_cpu"], + default="colossalai_zero2", + ) + parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama", "chatglm"], default="bloom") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--pretrain", type=str, default=None) + parser.add_argument("--dataset", type=str, default=None) + parser.add_argument("--max_datasets_size", type=int, default=None) + parser.add_argument("--save_path", type=str, default="output") + parser.add_argument("--need_optim_ckpt", type=bool, default=False) + parser.add_argument("--max_epochs", type=int, default=3) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--max_len", type=int, default=512) + parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log") + parser.add_argument("--lr", type=float, default=5e-6) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--grad_checkpoint", default=False, action="store_true") args = parser.parse_args() train(args) diff --git a/applications/Chat/inference/benchmark.py b/applications/Chat/inference/benchmark.py index 438a1e3ef1c7..dbb5490a63dc 100644 --- a/applications/Chat/inference/benchmark.py +++ b/applications/Chat/inference/benchmark.py @@ -84,28 +84,34 @@ def evaluate( if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - 'pretrained', - help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.') - parser.add_argument('--quant', - choices=['8bit', '4bit'], - default=None, - help='Quantization mode. Default: None (no quantization, fp16).') + "pretrained", + help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.", + ) + parser.add_argument( + "--quant", + choices=["8bit", "4bit"], + default=None, + help="Quantization mode. Default: None (no quantization, fp16).", + ) parser.add_argument( - '--gptq_checkpoint', + "--gptq_checkpoint", default=None, - help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.') - parser.add_argument('--gptq_group_size', - type=int, - default=128, - help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') + help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.", + ) + parser.add_argument( + "--gptq_group_size", + type=int, + default=128, + help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.", + ) args = parser.parse_args() - if args.quant == '4bit': - assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' + if args.quant == "4bit": + assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint." tokenizer = AutoTokenizer.from_pretrained(args.pretrained) - if args.quant == '4bit': + if args.quant == "4bit": with low_resource_init(): config = LlamaConfig.from_pretrained(args.pretrained) model = LlamaForCausalLM(config) @@ -114,12 +120,12 @@ def evaluate( else: model = LlamaForCausalLM.from_pretrained( args.pretrained, - load_in_8bit=(args.quant == '8bit'), + load_in_8bit=(args.quant == "8bit"), torch_dtype=torch.float16, device_map="auto", ) - if args.quant != '8bit': - model.half() # seems to fix bugs for some users. + if args.quant != "8bit": + model.half() # seems to fix bugs for some users. model.eval() total_tokens = 0 @@ -129,7 +135,7 @@ def evaluate( resp, tokens = evaluate(model, tokenizer, instruction, temperature=0.2, num_beams=1) total_tokens += tokens print(f"Response: {resp}") - print('\n----------------------------\n') + print("\n----------------------------\n") duration = time() - start - print(f'Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s') - print(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB') + print(f"Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s") + print(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB") diff --git a/applications/Chat/inference/locustfile.py b/applications/Chat/inference/locustfile.py index 9443d4b99180..333262e538ac 100644 --- a/applications/Chat/inference/locustfile.py +++ b/applications/Chat/inference/locustfile.py @@ -1,26 +1,26 @@ -from json import JSONDecodeError - from locust import HttpUser, task -samples = [[ - dict( - instruction='Who is the best player in the history of NBA?', - response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - dict(instruction='continue this talk', response=''), -], [ - dict(instruction='Who is the best player in the history of NBA?', response=''), -]] +samples = [ + [ + dict( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + dict(instruction="continue this talk", response=""), + ], + [ + dict(instruction="Who is the best player in the history of NBA?", response=""), + ], +] class GenerationUser(HttpUser): - @task def generate(self): for sample in samples: - data = {'max_new_tokens': 64, 'history': sample} - with self.client.post('/generate', json=data, catch_response=True) as response: + data = {"max_new_tokens": 64, "history": sample} + with self.client.post("/generate", json=data, catch_response=True) as response: if response.status_code in (200, 406): response.success() else: - response.failure('Response wrong') + response.failure("Response wrong") diff --git a/applications/Chat/inference/server.py b/applications/Chat/inference/server.py index 9d6b7fabef54..7c6a61b9e7f2 100644 --- a/applications/Chat/inference/server.py +++ b/applications/Chat/inference/server.py @@ -16,7 +16,7 @@ from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM from utils import ChatPromptProcessor, Dialogue, LockedIterator, load_json, sample_streamingly, update_model_kwargs_fn -CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' +CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions." MAX_LEN = 512 running_lock = Lock() @@ -36,11 +36,11 @@ class GenerationTaskReq(BaseModel): app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # set CORS -origin_spec_from_env = os.environ.get('CORS_ORIGIN', None) +origin_spec_from_env = os.environ.get("CORS_ORIGIN", None) if origin_spec_from_env is not None: # allow CORS from the specified origins - origins = os.environ['CORS_ORIGIN'].split(',') + origins = os.environ["CORS_ORIGIN"].split(",") else: # allow CORS from all origins origins = ["*"] @@ -58,13 +58,13 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} # TODO(ver217): streaming generation does not support repetition_penalty now model_kwargs = { - 'max_generate_tokens': max_new_tokens, - 'early_stopping': True, - 'top_k': top_k, - 'top_p': top_p, - 'temperature': temperature, - 'prepare_inputs_fn': model.prepare_inputs_for_generation, - 'update_model_kwargs_fn': update_model_kwargs_fn, + "max_generate_tokens": max_new_tokens, + "early_stopping": True, + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "prepare_inputs_fn": model.prepare_inputs_for_generation, + "update_model_kwargs_fn": update_model_kwargs_fn, } is_first_word = True generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock) @@ -81,9 +81,9 @@ def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature): if is_first_word: out_string = out_string.lstrip() is_first_word = False - elif current_sub_tokens[0].startswith('▁'): + elif current_sub_tokens[0].startswith("▁"): # whitespace will be ignored by the frontend - out_string = ' ' + out_string + out_string = " " + out_string yield out_string @@ -92,32 +92,33 @@ async def event_generator(request: Request, generator: Generator): if await request.is_disconnected(): break try: - yield {'event': 'generate', 'data': next(generator)} + yield {"event": "generate", "data": next(generator)} except StopIteration: - yield {'event': 'end', 'data': ''} + yield {"event": "end", "data": ""} break -@app.post('/generate/stream') -@limiter.limit('1/second') +@app.post("/generate/stream") +@limiter.limit("1/second") def generate(data: GenerationTaskReq, request: Request): prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) event_source = event_generator( - request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature)) + request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature) + ) return EventSourceResponse(event_source) -@app.post('/generate') -@limiter.limit('1/second') +@app.post("/generate") +@limiter.limit("1/second") def generate_no_stream(data: GenerationTaskReq, request: Request): prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens) if prompt_processor.has_censored_words(prompt): return prompt_processor.SAFE_RESPONSE inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()} with running_lock: - output = model.generate(**inputs, **data.dict(exclude={'history'})) + output = model.generate(**inputs, **data.dict(exclude={"history"})) output = output.cpu() - prompt_len = inputs['input_ids'].size(1) + prompt_len = inputs["input_ids"].size(1) response = output[0, prompt_len:] out_string = tokenizer.decode(response, skip_special_tokens=True) out_string = prompt_processor.postprocess_output(out_string) @@ -126,32 +127,40 @@ def generate_no_stream(data: GenerationTaskReq, request: Request): return out_string -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - 'pretrained', - help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.') - parser.add_argument('--quant', - choices=['8bit', '4bit'], - default=None, - help='Quantization mode. Default: None (no quantization, fp16).') + "pretrained", + help="Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.", + ) parser.add_argument( - '--gptq_checkpoint', + "--quant", + choices=["8bit", "4bit"], default=None, - help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.') - parser.add_argument('--gptq_group_size', - type=int, - default=128, - help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.') - parser.add_argument('--http_host', default='0.0.0.0') - parser.add_argument('--http_port', type=int, default=7070) - parser.add_argument('--profanity_file', - default=None, - help='Path to profanity words list. It should be a JSON file containing a list of words.') + help="Quantization mode. Default: None (no quantization, fp16).", + ) + parser.add_argument( + "--gptq_checkpoint", + default=None, + help="Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.", + ) + parser.add_argument( + "--gptq_group_size", + type=int, + default=128, + help="Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.", + ) + parser.add_argument("--http_host", default="0.0.0.0") + parser.add_argument("--http_port", type=int, default=7070) + parser.add_argument( + "--profanity_file", + default=None, + help="Path to profanity words list. It should be a JSON file containing a list of words.", + ) args = parser.parse_args() - if args.quant == '4bit': - assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.' + if args.quant == "4bit": + assert args.gptq_checkpoint is not None, "Please specify a GPTQ checkpoint." tokenizer = AutoTokenizer.from_pretrained(args.pretrained) @@ -161,7 +170,7 @@ def generate_no_stream(data: GenerationTaskReq, request: Request): censored_words = [] prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words) - if args.quant == '4bit': + if args.quant == "4bit": with low_resource_init(): config = LlamaConfig.from_pretrained(args.pretrained) model = LlamaForCausalLM(config) @@ -170,12 +179,12 @@ def generate_no_stream(data: GenerationTaskReq, request: Request): else: model = LlamaForCausalLM.from_pretrained( args.pretrained, - load_in_8bit=(args.quant == '8bit'), + load_in_8bit=(args.quant == "8bit"), torch_dtype=torch.float16, device_map="auto", ) - if args.quant != '8bit': - model.half() # seems to fix bugs for some users. + if args.quant != "8bit": + model.half() # seems to fix bugs for some users. model.eval() config = uvicorn.Config(app, host=args.http_host, port=args.http_port) diff --git a/applications/Chat/inference/tests/test_chat_prompt.py b/applications/Chat/inference/tests/test_chat_prompt.py index 23028d4959cb..9835e71894c6 100644 --- a/applications/Chat/inference/tests/test_chat_prompt.py +++ b/applications/Chat/inference/tests/test_chat_prompt.py @@ -3,41 +3,49 @@ from transformers import AutoTokenizer from utils import ChatPromptProcessor, Dialogue -CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.' -tokenizer = AutoTokenizer.from_pretrained(os.environ['PRETRAINED_PATH']) +CONTEXT = "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions." +tokenizer = AutoTokenizer.from_pretrained(os.environ["PRETRAINED_PATH"]) samples = [ - ([ - Dialogue( - instruction='Who is the best player in the history of NBA?', - response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - Dialogue(instruction='continue this talk', response=''), - ], 128, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' + ( + [ + Dialogue( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + Dialogue(instruction="continue this talk", response=""), + ], + 128, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n", ), - ([ - Dialogue( - instruction='Who is the best player in the history of NBA?', - response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - Dialogue(instruction='continue this talk', response=''), - ], 200, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n' + ( + [ + Dialogue( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + Dialogue(instruction="continue this talk", response=""), + ], + 200, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n", ), - ([ - Dialogue( - instruction='Who is the best player in the history of NBA?', - response='The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1' - ), - Dialogue(instruction='continue this talk', response=''), - ], 211, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n' + ( + [ + Dialogue( + instruction="Who is the best player in the history of NBA?", + response="The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1", + ), + Dialogue(instruction="continue this talk", response=""), + ], + 211, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n", ), - ([ - Dialogue(instruction='Who is the best player in the history of NBA?', response=''), - ], 128, - 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n' + ( + [ + Dialogue(instruction="Who is the best player in the history of NBA?", response=""), + ], + 128, + "Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n", ), ] @@ -49,5 +57,5 @@ def test_chat_prompt_processor(): assert prompt == result -if __name__ == '__main__': +if __name__ == "__main__": test_chat_prompt_processor() diff --git a/applications/Chat/inference/utils.py b/applications/Chat/inference/utils.py index e8e7b05ac719..af018adf6e9d 100644 --- a/applications/Chat/inference/utils.py +++ b/applications/Chat/inference/utils.py @@ -20,9 +20,9 @@ from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper -def prepare_logits_processor(top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None) -> LogitsProcessorList: +def prepare_logits_processor( + top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None +) -> LogitsProcessorList: processor_list = LogitsProcessorList() if temperature is not None and temperature != 1.0: processor_list.append(TemperatureLogitsWarper(temperature)) @@ -41,29 +41,30 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool: return unfinished_sequences.max() == 0 -def sample_streamingly(model: nn.Module, - input_ids: torch.Tensor, - max_generate_tokens: int, - early_stopping: bool = False, - eos_token_id: Optional[int] = None, - pad_token_id: Optional[int] = None, - top_k: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, - **model_kwargs) -> Generator: - +def sample_streamingly( + model: nn.Module, + input_ids: torch.Tensor, + max_generate_tokens: int, + early_stopping: bool = False, + eos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None, + **model_kwargs, +) -> Generator: logits_processor = prepare_logits_processor(top_k, top_p, temperature) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) for _ in range(max_generate_tokens): - model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else { - 'input_ids': input_ids - } + model_inputs = ( + prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids} + ) outputs = model(**model_inputs) - next_token_logits = outputs['logits'][:, -1, :] + next_token_logits = outputs["logits"][:, -1, :] # pre-process distribution next_token_logits = logits_processor(input_ids, next_token_logits) # sample @@ -107,25 +108,26 @@ def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict: if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1) + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) return model_kwargs class Dialogue(BaseModel): - instruction: str = Field(min_length=1, example='Count up from 1 to 500.') - response: str = Field(example='') + instruction: str = Field(min_length=1, example="Count up from 1 to 500.") + response: str = Field(example="") -def _format_dialogue(instruction: str, response: str = ''): - return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}' +def _format_dialogue(instruction: str, response: str = ""): + return f"\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}" -STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S)) +STOP_PAT = re.compile(r"(###|instruction:).*", flags=(re.I | re.S)) class ChatPromptProcessor: - SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.' + SAFE_RESPONSE = "The input/response contains inappropriate content, please rephrase your prompt." def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []): self.tokenizer = tokenizer @@ -138,42 +140,48 @@ def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str: if self.context_len is None: - self.context_len = len(self.tokenizer(self.context)['input_ids']) + self.context_len = len(self.tokenizer(self.context)["input_ids"]) if self.dialogue_placeholder_len is None: self.dialogue_placeholder_len = len( - self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids']) + self.tokenizer(_format_dialogue(""), add_special_tokens=False)["input_ids"] + ) prompt = self.context # the last dialogue must be in the prompt last_dialogue = history.pop() # the response of the last dialogue is empty - assert last_dialogue.response == '' - if len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False) - ['input_ids']) + max_new_tokens + self.context_len >= self.max_len: + assert last_dialogue.response == "" + if ( + len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)["input_ids"]) + + max_new_tokens + + self.context_len + >= self.max_len + ): # to avoid truncate placeholder, apply truncate to the original instruction - instruction_truncated = self.tokenizer(last_dialogue.instruction, - add_special_tokens=False, - truncation=True, - max_length=(self.max_len - max_new_tokens - self.context_len - - self.dialogue_placeholder_len))['input_ids'] + instruction_truncated = self.tokenizer( + last_dialogue.instruction, + add_special_tokens=False, + truncation=True, + max_length=(self.max_len - max_new_tokens - self.context_len - self.dialogue_placeholder_len), + )["input_ids"] instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip() prompt += _format_dialogue(instruction_truncated) return prompt - res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)['input_ids']) + res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)["input_ids"]) rows = [] for dialogue in history[::-1]: text = _format_dialogue(dialogue.instruction, dialogue.response) - cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids']) + cur_len = len(self.tokenizer(text, add_special_tokens=False)["input_ids"]) if res_len - cur_len < 0: break res_len -= cur_len rows.insert(0, text) - prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction) + prompt += "".join(rows) + _format_dialogue(last_dialogue.instruction) return prompt def postprocess_output(self, output: str) -> str: - output = STOP_PAT.sub('', output) + output = STOP_PAT.sub("", output) return output.strip() def has_censored_words(self, text: str) -> bool: @@ -184,7 +192,6 @@ def has_censored_words(self, text: str) -> bool: class LockedIterator: - def __init__(self, it, lock: Lock) -> None: self.lock = lock self.it = iter(it) diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt index eb1a77875acb..809fbd4bb86b 100644 --- a/applications/Chat/requirements-test.txt +++ b/applications/Chat/requirements-test.txt @@ -1,2 +1,2 @@ pytest -colossalai==0.3.1 \ No newline at end of file +colossalai==0.3.1 diff --git a/applications/Chat/setup.py b/applications/Chat/setup.py index a285a6dff4bf..eb44b6203ef8 100644 --- a/applications/Chat/setup.py +++ b/applications/Chat/setup.py @@ -2,40 +2,42 @@ def fetch_requirements(path): - with open(path, 'r') as fd: + with open(path, "r") as fd: return [r.strip() for r in fd.readlines()] def fetch_readme(): - with open('README.md', encoding='utf-8') as f: + with open("README.md", encoding="utf-8") as f: return f.read() def fetch_version(): - with open('version.txt', 'r') as f: + with open("version.txt", "r") as f: return f.read().strip() setup( - name='coati', + name="coati", version=fetch_version(), - packages=find_packages(exclude=( - 'tests', - 'benchmarks', - '*.egg-info', - )), - description='Colossal-AI Talking Intelligence', + packages=find_packages( + exclude=( + "tests", + "benchmarks", + "*.egg-info", + ) + ), + description="Colossal-AI Talking Intelligence", long_description=fetch_readme(), - long_description_content_type='text/markdown', - license='Apache Software License 2.0', - url='https://github.com/hpcaitech/Coati', - install_requires=fetch_requirements('requirements.txt'), - python_requires='>=3.6', + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech/Coati", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Environment :: GPU :: NVIDIA CUDA', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: System :: Distributed Computing', + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", ], ) diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py index 3a3bf5b19cb8..e3058be2e67c 100644 --- a/applications/Chat/tests/test_checkpoint.py +++ b/applications/Chat/tests/test_checkpoint.py @@ -22,10 +22,7 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict: return dict(input_ids=input_ids, attention_mask=attention_mask) -def train_step(strategy: Strategy, - actor: GPTActor, - actor_optim: HybridAdam, - batch_size: int = 8): +def train_step(strategy: Strategy, actor: GPTActor, actor_optim: HybridAdam, batch_size: int = 8): data = get_data(batch_size) action_mask = torch.ones_like(data["attention_mask"], dtype=torch.bool) actor_output = actor(data["input_ids"], data["attention_mask"]) @@ -35,8 +32,7 @@ def train_step(strategy: Strategy, strategy.optimizer_step(actor_optim) -def run_test_checkpoint(strategy_name: str, - shard: bool): +def run_test_checkpoint(strategy_name: str, shard: bool): if strategy_name == "ddp": strategy = DDPStrategy() elif strategy_name == "colossalai_gemini": @@ -60,11 +56,9 @@ def run_test_checkpoint(strategy_name: str, dist.broadcast_object_list(rank0_dirname) rank0_dirname = rank0_dirname[0] - model_path = os.path.join( - rank0_dirname, "model" if shard else f"model.pt") + model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt") strategy.save_model(actor, model_path, only_rank0=not shard) - optim_path = os.path.join( - rank0_dirname, "optim" if shard else "optim.pt") + optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt") strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard) dist.barrier() @@ -75,11 +69,7 @@ def run_test_checkpoint(strategy_name: str, train_step(strategy, actor, actor_optim) -def run_dist(rank: int, - world_size: int, - port: int, - strategy_name: str, - shard: bool): +def run_dist(rank: int, world_size: int, port: int, strategy_name: str, shard: bool): os.environ["RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) @@ -93,13 +83,8 @@ def run_dist(rank: int, @pytest.mark.parametrize("strategy_name", ["ddp", "colossalai_gemini", "colossalai_zero2"]) @pytest.mark.parametrize("shard", [False, True]) @rerun_if_address_is_in_use() -def test_checkpoint(world_size: int, - strategy_name: str, - shard: bool): - spawn(run_dist, - world_size, - strategy_name=strategy_name, - shard=shard) +def test_checkpoint(world_size: int, strategy_name: str, shard: bool): + spawn(run_dist, world_size, strategy_name=strategy_name, shard=shard) if __name__ == "__main__": diff --git a/applications/Chat/tests/test_dataset.py b/applications/Chat/tests/test_dataset.py index f9dee1bae935..3de2cc528967 100644 --- a/applications/Chat/tests/test_dataset.py +++ b/applications/Chat/tests/test_dataset.py @@ -8,62 +8,40 @@ from coati.dataset.prompt_dataset import PromptDataset from coati.dataset.reward_dataset import HhRlhfDataset, RmStaticDataset from coati.dataset.sft_dataset import IGNORE_INDEX, SFTDataset, SupervisedDataset +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from datasets import load_dataset from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, PreTrainedTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer + SFT_DATASET = [ { - "instruction": - "Provide a list of the top 10 most popular mobile games in Asia", - "input": - "", - "output": - "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", - "id": - 0 + "instruction": "Provide a list of the top 10 most popular mobile games in Asia", + "input": "", + "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved", + "id": 0, }, { - "instruction": - "Please provide an action plan for reducing carbon footprint on a corporate level", - "input": - "", - "output": - "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", - "id": - 1 + "instruction": "Please provide an action plan for reducing carbon footprint on a corporate level", + "input": "", + "output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.", + "id": 1, }, { - "instruction": - "Write a persuasive email to your boss explaining why you should have a pay raise", - "input": - "", - "output": - "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", - "id": - 2 + "instruction": "Write a persuasive email to your boss explaining why you should have a pay raise", + "input": "", + "output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]", + "id": 2, }, ] PROMPT_DATASET = [ { - "instruction": - "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"", - "id": - 0 - }, - { - "instruction": "Write a descriptive paragraph about a memorable vacation you went on", - "id": 1 - }, - { - "instruction": "Write a persuasive essay arguing why homework should be banned in schools", - "id": 2 - }, - { - "instruction": "Create a chart comparing the statistics on student debt in the United States.", - "id": 3 + "instruction": 'Edit this paragraph to make it more concise: "Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends."', + "id": 0, }, + {"instruction": "Write a descriptive paragraph about a memorable vacation you went on", "id": 1}, + {"instruction": "Write a persuasive essay arguing why homework should be banned in schools", "id": 2}, + {"instruction": "Create a chart comparing the statistics on student debt in the United States.", "id": 3}, ] @@ -120,10 +98,12 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): json.dump(PROMPT_DATASET, f) tokenizer = make_tokenizer(model) assert tokenizer.padding_side in ("left", "right") - prompt_dataset = PromptDataset(data_path=os.path.join(tmp_dir, dataset_name), - tokenizer=tokenizer, - max_datasets_size=max_datasets_size, - max_length=max_length) + prompt_dataset = PromptDataset( + data_path=os.path.join(tmp_dir, dataset_name), + tokenizer=tokenizer, + max_datasets_size=max_datasets_size, + max_length=max_length, + ) assert len(prompt_dataset) == min(max_datasets_size, len(PROMPT_DATASET)) for i in range(len(prompt_dataset)): assert isinstance(prompt_dataset[i], dict) @@ -137,14 +117,14 @@ def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int): @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"]) -@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), - ("Dahoas/rm-static", None)]) +@pytest.mark.parametrize( + ["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"), ("Dahoas/rm-static", None)] +) @pytest.mark.parametrize("max_datasets_size", [32]) @pytest.mark.parametrize("max_length", [32, 1024]) def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int): data = load_dataset(dataset_path, data_dir=subset) - assert max_datasets_size <= len(data["train"]) \ - and max_datasets_size <= len(data["test"]) + assert max_datasets_size <= len(data["train"]) and max_datasets_size <= len(data["test"]) train_data = data["train"].select(range(max_datasets_size)) test_data = data["test"].select(range(max_datasets_size)) tokenizer = make_tokenizer(model) @@ -162,8 +142,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma assert len(train_dataset) == len(test_dataset) == max_datasets_size for i in range(max_datasets_size): chosen_ids, c_mask, reject_ids, r_mask = train_dataset[i] - assert chosen_ids.shape == c_mask.shape == \ - reject_ids.shape == r_mask.shape == torch.Size([max_length]) + assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length]) c_mask = c_mask.to(torch.bool) r_mask = r_mask.to(torch.bool) if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: @@ -180,8 +159,7 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma assert torch.all(r_mask) chosen_ids, c_mask, reject_ids, r_mask = test_dataset[i] - assert chosen_ids.shape == c_mask.shape == \ - reject_ids.shape == r_mask.shape == torch.Size([max_length]) + assert chosen_ids.shape == c_mask.shape == reject_ids.shape == r_mask.shape == torch.Size([max_length]) c_mask = c_mask.to(torch.bool) r_mask = r_mask.to(torch.bool) if chosen_ids.masked_select(c_mask)[-1] == tokenizer.eos_token_id: @@ -198,7 +176,6 @@ def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], ma assert torch.all(r_mask) - @pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"]) @pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None]) @pytest.mark.parametrize("max_dataset_size", [2]) @@ -214,10 +191,12 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: dataset_name = "sft_dataset.json" with open(os.path.join(tmp_dir, dataset_name), "w") as f: json.dump(SFT_DATASET, f) - sft_dataset = SupervisedDataset(tokenizer=tokenizer, - data_path=os.path.join(tmp_dir, dataset_name), - max_datasets_size=max_dataset_size, - max_length=max_length) + sft_dataset = SupervisedDataset( + tokenizer=tokenizer, + data_path=os.path.join(tmp_dir, dataset_name), + max_datasets_size=max_dataset_size, + max_length=max_length, + ) assert len(sft_dataset) == min(max_dataset_size, len(SFT_DATASET)) if isinstance(tokenizer, ChatGLMTokenizer): @@ -227,20 +206,19 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: input_ids = sft_dataset[i]["input_ids"] labels = sft_dataset[i]["labels"] assert input_ids.shape == labels.shape == torch.Size([max_length]) - + ignore_mask = labels == IGNORE_INDEX assert input_ids.masked_select(torch.logical_not(ignore_mask))[0] == tokenizer.bos_token_id check_content(input_ids.masked_select(torch.logical_not(ignore_mask)), tokenizer, model) return - + for i in range(max_dataset_size): assert isinstance(sft_dataset[i], dict) assert list(sft_dataset[i].keys()) == ["input_ids", "labels", "attention_mask"] input_ids = sft_dataset[i]["input_ids"] labels = sft_dataset[i]["labels"] attention_mask = sft_dataset[i]["attention_mask"].to(torch.bool) - assert input_ids.shape == labels.shape == \ - attention_mask.shape == torch.Size([max_length]) + assert input_ids.shape == labels.shape == attention_mask.shape == torch.Size([max_length]) if input_ids.masked_select(attention_mask)[-1] == tokenizer.eos_token_id: check_content(input_ids.masked_select(attention_mask)[:-1], tokenizer, model) assert torch.all(input_ids.masked_select(torch.logical_not(attention_mask)) == tokenizer.pad_token_id) @@ -254,13 +232,8 @@ def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: if __name__ == "__main__": test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256) - test_reward_dataset(model="gpt2", - dataset_path="Anthropic/hh-rlhf", - subset="harmless-base", - max_datasets_size=8, - max_length=256) - - test_prompt_dataset(model="opt", - max_datasets_size=2, - max_length=128) + test_reward_dataset( + model="gpt2", dataset_path="Anthropic/hh-rlhf", subset="harmless-base", max_datasets_size=8, max_length=256 + ) + test_prompt_dataset(model="opt", max_datasets_size=2, max_length=128) diff --git a/applications/Chat/tests/test_experience.py b/applications/Chat/tests/test_experience.py index 071e50b90e8e..d0ea3bbd2ff5 100644 --- a/applications/Chat/tests/test_experience.py +++ b/applications/Chat/tests/test_experience.py @@ -18,7 +18,7 @@ def get_data(batch_size: int, seq_len: int = 10) -> dict: - input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda') + input_ids = torch.randint(0, 50257, (batch_size, seq_len), device="cuda") attention_mask = torch.ones_like(input_ids) return dict(input_ids=input_ids, attention_mask=attention_mask) @@ -37,12 +37,12 @@ def make_and_consume_experience(strategy): EXPERIENCE_BATCH_SIZE = 4 SAMPLE_BATCH_SIZE = 2 - if strategy == 'ddp': + if strategy == "ddp": strategy = DDPStrategy() - elif strategy == 'colossalai-zero2': + elif strategy == "colossalai-zero2": strategy = LowLevelZeroStrategy() - elif strategy == 'colossalai-gemini': - strategy = GeminiStrategy(placement_policy='cuda') + elif strategy == "colossalai-gemini": + strategy = GeminiStrategy(placement_policy="cuda") else: raise ValueError(f'Unsupported strategy "{strategy}"') @@ -58,13 +58,11 @@ def make_and_consume_experience(strategy): # experience of all ranks should be the same for _ in range(2): data = get_data(EXPERIENCE_BATCH_SIZE) - assert gather_and_equal(data['input_ids']) - assert gather_and_equal(data['attention_mask']) - experience = experience_maker.make_experience(**data, - do_sample=True, - max_length=16, - eos_token_id=50256, - pad_token_id=50256) + assert gather_and_equal(data["input_ids"]) + assert gather_and_equal(data["attention_mask"]) + experience = experience_maker.make_experience( + **data, do_sample=True, max_length=16, eos_token_id=50256, pad_token_id=50256 + ) assert gather_and_equal(experience.sequences) assert gather_and_equal(experience.action_log_probs) assert gather_and_equal(experience.values) @@ -75,7 +73,7 @@ def make_and_consume_experience(strategy): data_buffer.append(experience) # data buffer's data should be the same - buffer_size = torch.tensor([len(data_buffer)], device='cuda') + buffer_size = torch.tensor([len(data_buffer)], device="cuda") assert gather_and_equal(buffer_size) for item in data_buffer.items: assert gather_and_equal(item.sequences) @@ -88,7 +86,7 @@ def make_and_consume_experience(strategy): # dataloader of each rank should have the same size and different batch dataloader = strategy.setup_dataloader(data_buffer) - dataloader_size = torch.tensor([len(dataloader)], device='cuda') + dataloader_size = torch.tensor([len(dataloader)], device="cuda") assert gather_and_equal(dataloader_size) for experience in dataloader: assert not gather_and_equal(experience.sequences) @@ -100,21 +98,21 @@ def make_and_consume_experience(strategy): def run_dist(rank, world_size, port, strategy): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = str(port) + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) make_and_consume_experience(strategy) @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@pytest.mark.parametrize('strategy', ['ddp', 'colossalai-zero2', 'colossalai-gemini']) +@pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("strategy", ["ddp", "colossalai-zero2", "colossalai-gemini"]) @rerun_if_address_is_in_use() def test_experience(world_size, strategy): spawn(run_dist, world_size, strategy=strategy) -if __name__ == '__main__': - test_experience(2, 'colossalai') +if __name__ == "__main__": + test_experience(2, "colossalai") diff --git a/applications/Chat/tests/test_models.py b/applications/Chat/tests/test_models.py index b98b3615cd28..b2551ff5c0de 100644 --- a/applications/Chat/tests/test_models.py +++ b/applications/Chat/tests/test_models.py @@ -6,15 +6,16 @@ import torch.nn as nn from coati.models.base import Actor, Critic, RewardModel, get_base_model from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic +from coati.models.chatglm import ChatGLMActor +from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer from coati.models.generation import generate from coati.models.gpt import GPTRM, GPTActor, GPTCritic -from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM -from coati.models.chatglm import ChatGLMActor +from coati.models.llama import LlamaActor from coati.models.lora import LoraLinear, convert_to_lora_module from coati.models.loss import GPTLMLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean -from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer + @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seq_len", [32]) @@ -23,19 +24,24 @@ [ lambda: BLOOMActor(), lambda: GPTActor(), - # HACK: skip llama due to long execution time - # lambda: LlamaActor(), - lambda: OPTActor(), - # lambda: ChatGLMActor(), -]) - -@pytest.mark.parametrize("generate_kwargs", [{ - "max_length": 64, - "use_cache": True, - "do_sample": True, - "temperature": 1.0, - "top_k": 50, -}]) + # HACK: skip llama due to long execution time + # lambda: LlamaActor(), + lambda: OPTActor(), + # lambda: ChatGLMActor(), + ], +) +@pytest.mark.parametrize( + "generate_kwargs", + [ + { + "max_length": 64, + "use_cache": True, + "do_sample": True, + "temperature": 1.0, + "top_k": 50, + } + ], +) def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]): actor = actor_maker() input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda() @@ -56,7 +62,7 @@ def test_utils(): "kl_coef": 1.0, "log_probs": torch.randn((batch_size, num_labels)), "log_probs_base": torch.randn((batch_size, num_labels)), - "action_mask": torch.randint(0, 2, (batch_size, num_labels)) + "action_mask": torch.randint(0, 2, (batch_size, num_labels)), } fn_output = compute_reward(**fn_input) assert fn_output.shape == (batch_size,) @@ -66,9 +72,7 @@ def test_utils(): num_labels = 10 num_actions = 2 fn_input = { - "output": { - "logits": torch.randn((batch_size, seq_len, num_labels)) - }, + "output": {"logits": torch.randn((batch_size, seq_len, num_labels))}, "sequences": torch.randint(0, num_labels, (batch_size, seq_len)), "num_actions": num_actions, } @@ -105,8 +109,9 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int): assert isinstance(lora_model[i], LoraLinear) assert torch.allclose(old_model[i].weight, lora_model[i].weight) assert torch.allclose(old_model[i].bias, lora_model[i].bias) - assert not torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, - lora_model[i].lora_B @ lora_model[i].lora_A) + assert not torch.allclose( + old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A + ) @pytest.mark.parametrize("batch_size", [8]) @@ -116,54 +121,60 @@ def test_lora(lora_rank: int, num_dim: int, num_layers: int): [ lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), lambda: (GPTActor(), GPTCritic(), GPTRM()), - # HACK: skip llama due to long execution time - # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), - lambda: (OPTActor(), OPTCritic(), OPTRM()), - lambda: (ChatGLMActor(), None, None), -]) + # HACK: skip llama due to long execution time + # lambda: (LlamaActor(), LlamaCritic(), LlamaRM()), + lambda: (OPTActor(), OPTCritic(), OPTRM()), + lambda: (ChatGLMActor(), None, None), + ], +) @torch.no_grad() -def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], - batch_size: int, - seq_len: int): +def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]], batch_size: int, seq_len: int): actor_input = { "input_ids": torch.randint(0, 100, (batch_size, seq_len)), - "attention_mask": torch.randint(0, 2, (batch_size, seq_len)) + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), } critic_input = { "sequences": torch.randint(0, 100, (batch_size, seq_len)), "action_mask": torch.randint(0, 2, (batch_size, seq_len)), - "attention_mask": torch.randint(0, 2, (batch_size, seq_len)) + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), } rm_input = { "sequences": torch.randint(0, 100, (batch_size, seq_len)), - "attention_mask": torch.randint(0, 2, (batch_size, seq_len)) + "attention_mask": torch.randint(0, 2, (batch_size, seq_len)), } actor, critic, rm = models_maker() if isinstance(actor, ChatGLMActor): actor = actor.float() - tokenizer = ChatGLMTokenizer.from_pretrained( "THUDM/chatglm-6b", trust_remote_code=True) + tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) chatglm_special_token = torch.tensor([tokenizer.gmask_token_id, tokenizer.bos_token_id]).repeat(batch_size, 1) - actor_input ={ - "input_ids": torch.cat((torch.randint(0, 100, (batch_size, seq_len//2)), chatglm_special_token, torch.randint(0, 100, (batch_size, seq_len//2 - 2))), dim=1), - "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)) - } + actor_input = { + "input_ids": torch.cat( + ( + torch.randint(0, 100, (batch_size, seq_len // 2)), + chatglm_special_token, + torch.randint(0, 100, (batch_size, seq_len // 2 - 2)), + ), + dim=1, + ), + "attention_mask": torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)), + } assert isinstance(actor, Actor) - base_actor_model = get_base_model(actor) + get_base_model(actor) actor_output = actor(**actor_input) assert actor_output.logits.shape[:2] == (batch_size, seq_len) if critic: assert isinstance(critic, Critic) - base_critic_model = get_base_model(critic) + get_base_model(critic) critic_output = critic(**critic_input) - assert critic_output.shape == (batch_size, ) - + assert critic_output.shape == (batch_size,) + if rm: assert isinstance(rm, RewardModel) - base_rm_model = get_base_model(rm) + get_base_model(rm) rm_output = rm(**rm_input) - assert rm_output.shape == (batch_size, ) + assert rm_output.shape == (batch_size,) @pytest.mark.parametrize("batch_size", [16]) @@ -173,39 +184,59 @@ def test_loss(batch_size: int, seq_len: int, num_labels: int): loss = GPTLMLoss() loss_input = { "logits": torch.randn(batch_size, seq_len, num_labels), - "labels": torch.randint(0, num_labels, (batch_size, seq_len)) + "labels": torch.randint(0, num_labels, (batch_size, seq_len)), } - loss_output = loss(**loss_input) + loss(**loss_input) loss = PolicyLoss() loss_input = { - "log_probs": torch.randn(batch_size,), - "old_log_probs": torch.randn(batch_size,), - "advantages": torch.randn(batch_size,) + "log_probs": torch.randn( + batch_size, + ), + "old_log_probs": torch.randn( + batch_size, + ), + "advantages": torch.randn( + batch_size, + ), } - loss_output = loss(**loss_input) + loss(**loss_input) loss = ValueLoss() loss_input = { - "values": torch.randn(batch_size,), - "old_values": torch.randn(batch_size,), - "reward": torch.randn(batch_size,) + "values": torch.randn( + batch_size, + ), + "old_values": torch.randn( + batch_size, + ), + "reward": torch.randn( + batch_size, + ), } - loss_output = loss(**loss_input) + loss(**loss_input) loss = LogSigLoss() loss_input = { - "chosen_reward": torch.randn(batch_size,), - "reject_reward": torch.randn(batch_size,), + "chosen_reward": torch.randn( + batch_size, + ), + "reject_reward": torch.randn( + batch_size, + ), } - loss_output = loss(**loss_input) + loss(**loss_input) loss = LogExpLoss() loss_input = { - "chosen_reward": torch.randn(batch_size,), - "reject_reward": torch.randn(batch_size,), + "chosen_reward": torch.randn( + batch_size, + ), + "reject_reward": torch.randn( + batch_size, + ), } - loss_output = loss(**loss_input) + loss(**loss_input) if __name__ == "__main__": @@ -218,4 +249,4 @@ def test_loss(batch_size: int, seq_len: int, num_labels: int): test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128) - test_loss(batch_size=8, seq_len=128, num_labels=100) \ No newline at end of file + test_loss(batch_size=8, seq_len=128, num_labels=100) diff --git a/colossalai/__init__.py b/colossalai/__init__.py index fa6f72a605c0..7da55590305b 100644 --- a/colossalai/__init__.py +++ b/colossalai/__init__.py @@ -6,7 +6,7 @@ except ModuleNotFoundError: # this will only happen if the user did not run `pip install` # and directly set PYTHONPATH to use Colossal-AI which is a bad practice - __version__ = '0.0.0' - print('please install Colossal-AI from https://www.colossalai.org/download or from source') + __version__ = "0.0.0" + print("please install Colossal-AI from https://www.colossalai.org/download or from source") -__all__ = ['launch', 'launch_from_openmpi', 'launch_from_slurm', 'launch_from_torch', '__version__'] +__all__ = ["launch", "launch_from_openmpi", "launch_from_slurm", "launch_from_torch", "__version__"] diff --git a/colossalai/_analyzer/_subclasses/_meta_registration.py b/colossalai/_analyzer/_subclasses/_meta_registration.py index 4049be79c70f..e8ba88b0406d 100644 --- a/colossalai/_analyzer/_subclasses/_meta_registration.py +++ b/colossalai/_analyzer/_subclasses/_meta_registration.py @@ -3,7 +3,7 @@ # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # for more meta_registrations -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from packaging import version @@ -24,25 +24,23 @@ def new(*args, **kwargs): - return orig_empty(*args, **kwargs, device=torch.device('meta')) + return orig_empty(*args, **kwargs, device=torch.device("meta")) def new_strided(*args, **kwargs): - return orig_empty_strided(*args, **kwargs, device=torch.device('meta')) + return orig_empty_strided(*args, **kwargs, device=torch.device("meta")) def new_like(*args, **kwargs): - return orig_empty_like(*args, **kwargs, device=torch.device('meta')) + return orig_empty_like(*args, **kwargs, device=torch.device("meta")) def register_meta(op, register_dispatcher=True): - def wrapper(f): - def add_func(op): meta_table[op] = f if register_dispatcher: - name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) + name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__ try: meta_lib.impl(name, f) except: @@ -54,7 +52,7 @@ def add_func(op): return wrapper -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): # ============================== Convolutions ====================================== # https://github.com/pytorch/pytorch/pull/79834 @register_meta(aten.convolution.default) @@ -69,7 +67,6 @@ def meta_conv( output_padding: List[int], groups: int, ): - def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ Formula to apply to calculate the length of some dimension of the output @@ -146,7 +143,8 @@ def calc_conv_nd_return_shape( kernel_size[i], stride[i], output_padding_list[i], - )) + ) + ) else: ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) return ret_shape @@ -180,19 +178,39 @@ def pick_memory_format(): shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) mem_fmt = pick_memory_format() - out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] return out @register_meta(aten._convolution.default) - def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], - padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, - *extra_args): + def meta__conv( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, + *extra_args, + ): out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) return out @register_meta(aten.convolution_backward.default) - def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, - padding, dilation, transposed, output_padding, groups, output_mask): + def meta_conv_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, + ): return new_like(input), new_like(weight), new((bias_sizes)) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -224,7 +242,6 @@ def meta_cuda_rnn( batch_sizes, dropout_state, ): - is_input_packed = len(batch_sizes) != 0 if is_input_packed: seq_length = len(batch_sizes) @@ -240,8 +257,11 @@ def meta_cuda_rnn( if is_input_packed: out_shape = [batch_sizes_sum, out_size * num_directions] else: - out_shape = ([mini_batch, seq_length, out_size * - num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) + out_shape = ( + [mini_batch, seq_length, out_size * num_directions] + if batch_first + else [seq_length, mini_batch, out_size * num_directions] + ) output = input.new_empty(out_shape) cell_shape = [num_layers * num_directions, mini_batch, hidden_size] @@ -257,15 +277,21 @@ def meta_cuda_rnn( # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp @register_meta(aten._cudnn_rnn_backward.default) - def meta_cudnn_rnn_backward(input: torch.Tensor, - weight: torch.Tensor, - weight_stride0: int, - hx: torch.Tensor, - cx: Optional[torch.Tensor] = None, - *args, - **kwargs): - return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new( - ()) # (grad_input, grad_weight, grad_hx, grad_cx) + def meta_cudnn_rnn_backward( + input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs, + ): + return ( + new_like(input), + new_like(weight), + new_like(hx), + new_like(cx) if cx is not None else new(()), + ) # (grad_input, grad_weight, grad_hx, grad_cx) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp # ============================== Activations ======================================= @@ -278,7 +304,7 @@ def meta_cudnn_rnn_backward(input: torch.Tensor, aten.hardtanh_backward.default, ] - if version.parse(torch.__version__) < version.parse('2.0.0'): + if version.parse(torch.__version__) < version.parse("2.0.0"): _unregistered_ewise += [ aten.prelu_backward.default, ] @@ -296,37 +322,61 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.native_batch_norm_backward.default) - def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, - save_mean, save_invstd, train, eps, output_mask): - return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) + def meta_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, + ): + return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.cudnn_batch_norm.default) def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): n_input = input.size(1) - return new_like(input), new((n_input)), new((n_input)), new( - (0), dtype=torch.uint8) # (output, running_mean, running_var, reserve) + return ( + new_like(input), + new((n_input)), + new((n_input)), + new((0), dtype=torch.uint8), + ) # (output, running_mean, running_var, reserve) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp # NB: CuDNN only implements the backward algorithm for batchnorm # in training mode (evaluation mode batchnorm has a different algorithm), # which is why this doesn't accept a 'training' parameter. @register_meta(aten.cudnn_batch_norm_backward.default) - def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, - save_mean, save_invstd, eps, reserve): - return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) + def meta_cudnn_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + eps, + reserve, + ): + return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm.default) def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): bs, n_input = input.size(0), input.size(1) - return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var) + return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm_backward.default) - def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, - grad_input_mask): - return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta) + def meta_ln_backward( + dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask + ): + return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta) # ================================== Misc ========================================== # Maybe incorrect @@ -355,8 +405,9 @@ def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Te # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp @register_meta(aten.embedding_dense_backward.default) - def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, - scale_grad_by_freq): + def meta_embedding_dense_backward( + grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq + ): return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout) # ============================== Dropout =========================================== @@ -364,14 +415,14 @@ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tens @register_meta(aten.native_dropout.default) def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False): # notice that mask is bool - return new_like(input), new_like(input, dtype=torch.bool) # (output, mask) + return new_like(input), new_like(input, dtype=torch.bool) # (output, mask) # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp @register_meta(aten.native_dropout_backward.default) def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float): - return new_like(grad) # (grad_in) + return new_like(grad) # (grad_in) - if version.parse(torch.__version__) < version.parse('1.13.0'): + if version.parse(torch.__version__) < version.parse("1.13.0"): # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml @register_meta(aten.eye.m_out) def meta_eye(n: int, m: int, out: torch.Tensor): @@ -385,24 +436,28 @@ def meta_index_Tensor(self, indices): result: List[Optional[torch.Tensor]] = [] for i, index in enumerate(indices): if index is not None: - assert index.dtype in [torch.long, torch.int8, torch.bool],\ - "tensors used as indices must be long, byte or bool tensors" + assert index.dtype in [ + torch.long, + torch.int8, + torch.bool, + ], "tensors used as indices must be long, byte or bool tensors" if index.dtype in [torch.int8, torch.bool]: nonzero = index.nonzero() k = len(result) assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" for j in range(index.ndim): - assert index.shape[j] == self.shape[ - k + - j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + assert ( + index.shape[j] == self.shape[k + j] + ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" result.append(nonzero.select(1, j)) else: result.append(index) else: result.append(index) indices = result - assert len( - indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" + assert ( + len(indices) <= self.ndim + ), f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" # expand_outplace import torch._refs as refs diff --git a/colossalai/_analyzer/_subclasses/_monkey_patch.py b/colossalai/_analyzer/_subclasses/_monkey_patch.py index b3ec98f0811f..503981409cca 100644 --- a/colossalai/_analyzer/_subclasses/_monkey_patch.py +++ b/colossalai/_analyzer/_subclasses/_monkey_patch.py @@ -1,5 +1,4 @@ import torch -import torch.distributed as dist from packaging import version __all__ = [ @@ -48,7 +47,7 @@ "scatter", ] -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): aten = torch.ops.aten # TODO: dive deep here # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py index 59991dc50912..9d52c5593bb8 100644 --- a/colossalai/_analyzer/_subclasses/flop_tensor.py +++ b/colossalai/_analyzer/_subclasses/flop_tensor.py @@ -8,7 +8,7 @@ from enum import Enum, auto from functools import partial, reduce from numbers import Number -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Union import torch from packaging import version @@ -36,15 +36,15 @@ def _format_flops(flop): B = 1e9 T = 1e12 if flop < K: - return f'{flop:.2f}' + return f"{flop:.2f}" elif flop < M: - return f'{flop / K:.2f}K' + return f"{flop / K:.2f}K" elif flop < B: - return f'{flop / M:.2f}M' + return f"{flop / M:.2f}M" elif flop < T: - return f'{flop / B:.2f}B' + return f"{flop / B:.2f}B" else: - return f'{flop / T:.2f}T' + return f"{flop / T:.2f}T" def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number: @@ -59,11 +59,13 @@ def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: Returns: Number: The total number of floating point operations (FWD + BWD). """ - maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False) - or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_')) + maybe_inplace = ( + getattr(module, "inplace", False) + or kwargs.get("inplace", False) + or getattr(module, "__name__", None) in ("add_", "mul_", "div_", "sub_") + ) class DummyModule(torch.nn.Module): - def __init__(self, func): super().__init__() self.func = func @@ -74,21 +76,20 @@ def forward(self, *args, **kwargs): total_flop_count = {Phase.FWD: 0, Phase.BWD: 0} flop_counts = defaultdict(lambda: defaultdict(int)) - parents = ['Global'] + parents = ["Global"] module = module if isinstance(module, torch.nn.Module) else DummyModule(module) class FlopTensor(MetaTensor): _tensor: torch.Tensor def __repr__(self): - name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor' + name = "FlopParameter" if getattr(self, "_is_param", False) else "FlopTensor" if self.grad_fn: return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - # no_dispatch is only needed if you use enable_python_mode. # It prevents infinite recursion. rs = super().__torch_dispatch__(func, types, args, kwargs) @@ -115,9 +116,7 @@ def is_autogradable(x): return isinstance(x, torch.Tensor) and x.is_floating_point() def create_backwards_push(name): - class PushState(torch.autograd.Function): - @staticmethod def forward(ctx, *args): args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) @@ -134,9 +133,7 @@ def backward(ctx, *grad_outs): return PushState.apply def create_backwards_pop(name): - class PopState(torch.autograd.Function): - @staticmethod def forward(ctx, *args): args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) @@ -147,14 +144,13 @@ def forward(ctx, *args): @staticmethod def backward(ctx, *grad_outs): nonlocal parents - assert (parents[-1] == name) + assert parents[-1] == name parents.pop() return grad_outs return PopState.apply def enter_module(name): - def f(module, inputs): nonlocal parents parents.append(name) @@ -165,10 +161,9 @@ def f(module, inputs): return f def exit_module(name): - def f(module, inputs, outputs): nonlocal parents - assert (parents[-1] == name) + assert parents[-1] == name parents.pop() outputs = normalize_tuple(outputs) return create_backwards_push(name)(*outputs) @@ -189,7 +184,7 @@ def display_flops(): for mod in flop_counts.keys(): print(f"Module: ", mod) for k, v in flop_counts[mod].items(): - print('\t', k, _format_flops(v)) + print("\t", k, _format_flops(v)) print() def detach_variables(r): @@ -201,7 +196,7 @@ def detach_variables(r): def wrap(r): if isinstance(r, torch.Tensor): - data_ptr_fn = getattr(r, '_tensor', r).data_ptr + data_ptr_fn = getattr(r, "_tensor", r).data_ptr r = FlopTensor(detach_variables(r)) if maybe_inplace: r = r + 0 @@ -375,8 +370,11 @@ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: # Inputs[0] contains the shape of the input. input_shape = inputs[input_arg_index].shape - has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], - 'shape') else inputs[affine_arg_index] + has_affine = ( + inputs[affine_arg_index].shape is not None + if hasattr(inputs[affine_arg_index], "shape") + else inputs[affine_arg_index] + ) assert 2 <= len(input_shape) <= 5, input_shape # 5 is just a rough estimate flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) @@ -390,7 +388,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N training = inputs[-3] assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" if training: - return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore + return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore has_affine = inputs[1].shape is not None input_shape = reduce(operator.mul, inputs[0].shape) return input_shape * (2 if has_affine else 1) @@ -420,33 +418,30 @@ def ewise_flop(inputs: List[Any], outputs: List[Any]) -> Number: def zero_flop_jit(*args): """ - Count flops for zero flop layers. + Count flops for zero flop layers. """ return 0 -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): flop_mapping = { - # gemm + # gemm aten.mm.default: matmul_flop_jit, aten.matmul.default: matmul_flop_jit, aten.addmm.default: addmm_flop_jit, aten.bmm.default: bmm_flop_jit, - - # convolution + # convolution aten.convolution.default: conv_flop_jit, aten._convolution.default: conv_flop_jit, aten.convolution_backward.default: conv_backward_flop_jit, - - # normalization + # normalization aten.native_batch_norm.default: batchnorm_flop_jit, aten.native_batch_norm_backward.default: batchnorm_flop_jit, aten.cudnn_batch_norm.default: batchnorm_flop_jit, aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), aten.native_layer_norm.default: norm_flop_counter(2, 0), aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), - - # pooling + # pooling aten.avg_pool1d.default: ewise_flop_counter(1, 0), aten.avg_pool2d.default: ewise_flop_counter(1, 0), aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1), @@ -469,7 +464,7 @@ def zero_flop_jit(*args): } ewise_flop_aten = [ - # basic op + # basic op aten.add.Tensor, aten.add_.Tensor, aten.div.Tensor, @@ -485,8 +480,7 @@ def zero_flop_jit(*args): aten.sum.default, aten.sum.dim_IntList, aten.mean.dim, - - # activation op + # activation op aten.hardswish.default, aten.hardswish_.default, aten.hardswish_backward.default, @@ -509,15 +503,12 @@ def zero_flop_jit(*args): aten.tanh.default, aten.tanh_backward.default, aten.threshold_backward.default, - - # dropout + # dropout aten.native_dropout.default, aten.native_dropout_backward.default, - - # distribution + # distribution aten.bernoulli_.float, - - # where + # where aten.where.self, ] for op in ewise_flop_aten: diff --git a/colossalai/_analyzer/_subclasses/meta_tensor.py b/colossalai/_analyzer/_subclasses/meta_tensor.py index 2bc212938ee0..8be97d01343e 100644 --- a/colossalai/_analyzer/_subclasses/meta_tensor.py +++ b/colossalai/_analyzer/_subclasses/meta_tensor.py @@ -3,12 +3,12 @@ import torch import torch.distributed as dist -from torch.types import _bool, _device, _dtype -from torch.utils._pytree import tree_flatten, tree_map +from torch.types import _device +from torch.utils._pytree import tree_map from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod -__all__ = ['MetaTensor', 'MetaTensorMode'] +__all__ = ["MetaTensor", "MetaTensorMode"] def register_storage(r, data_ptr_fn=None): @@ -28,8 +28,7 @@ def _normalize_tuple(x): # a hack of inplace execution in PyTorch def _assert_alias(func): - return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive - ) + return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen) # TODO: check if should be this aggressive class MetaTensor(torch.Tensor): @@ -65,14 +64,15 @@ def __new__(cls, elem, device=None, data_ptr_fn=None): storage_offset=elem.storage_offset(), dtype=elem.dtype, layout=elem.layout, - device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), - requires_grad=requires_grad) # deceive the frontend for aten selections + device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")), + requires_grad=requires_grad, + ) # deceive the frontend for aten selections r._tensor = elem # ...the real tensor is held as an element on the tensor. if not r._tensor.is_meta: val = elem.data_ptr() data_ptr_fn = lambda: val - r._tensor = r._tensor.to(torch.device('meta')) + r._tensor = r._tensor.to(torch.device("meta")) # only tensor not on `meta` should be copied to `meta` register_storage(r._tensor, data_ptr_fn) @@ -81,7 +81,7 @@ def __new__(cls, elem, device=None, data_ptr_fn=None): return r def __repr__(self): - name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor' + name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor" if self.grad_fn: return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})" return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})" @@ -97,15 +97,15 @@ def unwrap(x): x = x._tensor elif isinstance(x, torch.Tensor): device = x.device - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return x args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) - if 'device' in kwargs: - device = kwargs['device'] - kwargs['device'] = torch.device('meta') + if "device" in kwargs: + device = kwargs["device"] + kwargs["device"] = torch.device("meta") # run aten for backend=CPU but actually on backend=Meta # here we detect whether or not the execution generates a physical copy @@ -143,21 +143,21 @@ def replace(x): nonlocal device if isinstance(x, str) or isinstance(x, _device): device = x - return torch.device('meta') + return torch.device("meta") return x elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) return MetaTensor(elem, device=device) def cpu(self, *args, **kwargs): - if self.device.type == 'cpu': + if self.device.type == "cpu": return self.to(*args, **kwargs) - return self.to(*args, device='cpu', **kwargs) + return self.to(*args, device="cpu", **kwargs) def cuda(self, device=None, non_blocking=False): if device is not None: return self.to(device=device, non_blocking=non_blocking) - return self.to(device='cuda:0', non_blocking=non_blocking) + return self.to(device="cuda:0", non_blocking=non_blocking) def data_ptr(self): return self._tensor.data_ptr() @@ -177,19 +177,17 @@ class MetaTensorMode(object): """ def __init__(self): - self.torch_overrides = {} # override torch.xxx - self.dist_overrides = {} # override torch.distributed.xxx + self.torch_overrides = {} # override torch.xxx + self.dist_overrides = {} # override torch.distributed.xxx def __enter__(self): - def _dummy(*args, **kwargs): pass def _new(*args, orig_new=torch.empty, **kwargs): - return MetaTensor(orig_new(*args, **{ - **kwargs, 'device': 'meta' - }), - device=kwargs.get('device', torch.device('cpu'))) + return MetaTensor( + orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu")) + ) for func in _TorchOverrideableFactoryMethod: self.torch_overrides[func] = getattr(torch, func) diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py index 41d74f2e3719..cd244b22cac0 100644 --- a/colossalai/_analyzer/fx/codegen.py +++ b/colossalai/_analyzer/fx/codegen.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Any, Dict, List, Tuple import torch @@ -22,7 +22,7 @@ import colossalai from colossalai.fx._compatibility import compatibility -_register_custom_builtin('colossalai', 'import colossalai', colossalai) +_register_custom_builtin("colossalai", "import colossalai", colossalai) def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str: @@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True): """ Generate the checkpoint function call code text """ - outputs = ', '.join(output_vars) - inputs = ', '.join(input_vars) - return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})' + outputs = ", ".join(output_vars) + inputs = ", ".join(input_vars) + return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})" def _end_of_ckpt(node: Node, ckpt_level: int) -> bool: """ Check if the node could end the ckpt region at `ckpt_level` """ - if len(node.meta['info'].activation_checkpoint) > ckpt_level: - return node.meta['info'].activation_checkpoint[ckpt_level] is not None + if len(node.meta["info"].activation_checkpoint) > ckpt_level: + return node.meta["info"].activation_checkpoint[ckpt_level] is not None return True @@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): current_region = None for idx, node in enumerate(node_list): - if len(node.meta['info'].activation_checkpoint) > ckpt_level: - act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level] + if len(node.meta["info"].activation_checkpoint) > ckpt_level: + act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): return ckpt_regions -def emit_ckpt_func(body, - ckpt_func, - node_list: List[Node], - emit_node_func, - delete_unused_value_func, - ckpt_level=0, - in_ckpt=False): +def emit_ckpt_func( + body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False +): """Emit ckpt function in nested way Args: @@ -156,12 +152,12 @@ def emit_ckpt_func(body, # label given by each layer, e.g. if you are currently at level (0, 1, 1) # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]]) + label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]]) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") # if there is more level to fetch - if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)): + if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)): ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] @@ -174,33 +170,40 @@ def emit_ckpt_func(body, break if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] - emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func, - ckpt_level + 1, True) + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func( + ckpt_func, + ckpt_func_buffer, + ckpt_node_list, + emit_node_func, + delete_unused_value_func, + ckpt_level + 1, + True, + ) node_idx += len(ckpt_node_list) else: node = node_list[node_idx] emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) node_idx += 1 - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") ckpt_func += ckpt_func_buffer # last level else: for node in node_list: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") - usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n' + usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n" if in_ckpt: - usage = ' ' + usage + usage = " " + usage body.append(usage) @@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # process ckpt_regions if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) node_idx += len(ckpt_node_list) @@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, @compatibility(is_backward_compatible=True) class ActivationCheckpointCodeGen(CodeGen): - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] @@ -251,7 +253,7 @@ def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> Py wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] + maybe_return_annotation: List[str] = [""] def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -259,7 +261,7 @@ def add_global(name_hint: str, obj: Any): Graph, like functions or types. Returns: the global name that should be used to reference 'obj' in generated source. """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -281,16 +283,16 @@ def add_global(name_hint: str, obj: Any): def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): # This is a generic type, e.g. typing.List[torch.Tensor] origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) - if hasattr(o, '__args__'): + if hasattr(o, "__args__"): # Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__] @@ -309,19 +311,18 @@ def type_repr(o: Any): return add_global(typename, o) def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, '_fields'): + if isinstance(arg, tuple) and hasattr(arg, "_fields"): qualified_name = _get_qualified_name(type(arg)) global_name = add_global(qualified_name, type(arg)) return f"{global_name}{repr(tuple(arg))}" return repr(arg) - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' + return f"{args_s}, {kwargs_s}" return args_s or kwargs_s # Run through reverse nodes and record the first instance of a use @@ -347,82 +348,94 @@ def delete_unused_values(user: Node, body): not used in the remainder of the code are freed and the memory usage of the code is optimal. """ - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) + body.append(f"; {to_delete_str}\n") else: - body.append('\n') + body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: - body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') + if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods: + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" body.append(self.generate_output(node.args[0])) return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") # Modified for activation checkpointing ckpt_func = [] @@ -432,13 +445,13 @@ def emit_node(node: Node, body): # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') + body.append("pass\n") if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" if self._body_transformer: body = self._body_transformer(body) @@ -447,11 +460,11 @@ def emit_node(node: Node, body): add_global(name, value) prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = ''.join(ckpt_func) + prologue + prologue = "".join(ckpt_func) + prologue prologue = prologue - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) fn_code = f""" {wrap_stmts} {prologue} diff --git a/colossalai/_analyzer/fx/graph_module.py b/colossalai/_analyzer/fx/graph_module.py index 1fdedd758c01..9d3999e322b9 100644 --- a/colossalai/_analyzer/fx/graph_module.py +++ b/colossalai/_analyzer/fx/graph_module.py @@ -13,6 +13,7 @@ try: from torch.fx.graph import _PyTreeCodeGen + SUPPORT_PT_CODEGEN = True except ImportError: SUPPORT_PT_CODEGEN = False @@ -24,7 +25,6 @@ # This is a copy of torch.fx.graph_module._WrappedCall. # It should be removed when we stop supporting torch < 1.12.0. class _WrappedCall: - def __init__(self, cls, cls_call): self.cls = cls self.cls_call = cls_call @@ -50,12 +50,14 @@ def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: # constituent substrings of the error message tb_repr = traceback.format_exc() - custom_msg = ("Call using an FX-traced Module, " - f"line {err_lineno} of the traced Module's " - "generated forward function:") - before_err = "".join(all_src_lines[err_lineno - 2:err_lineno]) + custom_msg = ( + "Call using an FX-traced Module, " + f"line {err_lineno} of the traced Module's " + "generated forward function:" + ) + before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) marker = "~" * err_line_len + "~~~ <--- HERE" - err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2]) + err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) # joined message return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) @@ -65,11 +67,14 @@ def __call__(self, obj, *args, **kwargs): if self.cls_call is not None: return self.cls_call(obj, *args, **kwargs) else: - return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] + return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] except Exception as e: assert e.__traceback__ - topmost_framesummary: traceback.FrameSummary = \ - traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type] + topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract( + traceback.walk_tb(e.__traceback__) + )[ + -1 + ] # type: ignore[arg-type] if "eval_with_key" in topmost_framesummary.filename: print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr) raise e.with_traceback(None) @@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule): code. """ - def __init__(self, - root: Union[torch.nn.Module, Dict[str, Any]], - graph: torch.fx.Graph, - class_name: str = 'GraphModule'): + def __init__( + self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule" + ): super().__init__(root, graph, class_name) def bind(self, ckpt_def, globals): @@ -134,7 +138,7 @@ def recompile(self) -> PythonCode: if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module='self') + python_code = self._graph.python_code(root_module="self") self._code = python_code.src # To split ckpt functions code and forward code @@ -157,8 +161,8 @@ def recompile(self) -> PythonCode: # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. cls_call = cls.__call__ if "__call__" in vars(cls) else None - if '_wrapped_call' not in vars(cls): - cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] def call_wrapped(self, *args, **kwargs): return self._wrapped_call(self, *args, **kwargs) @@ -182,7 +186,7 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul """ folder = Path(folder) Path(folder).mkdir(exist_ok=True) - torch.save(self.state_dict(), folder / 'state_dict.pt') + torch.save(self.state_dict(), folder / "state_dict.pt") tab = " " * 4 # we add import colossalai here @@ -208,10 +212,10 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: for module_name, module in self.named_children(): module_str = _gen_model_repr(module_name, module) if module_str is None: - module_file = folder / f'{module_name}.pt' + module_file = folder / f"{module_name}.pt" torch.save(module, module_file) blobified_modules.append(module_name) - module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") module_str = f"torch.load(r'{module_file}') # {module_repr}" model_str += f"{tab*2}self.{module_name} = {module_str}\n" @@ -228,12 +232,14 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" model_str += f"{_addindent(self.code, 4)}\n" - module_file = folder / 'module.py' + module_file = folder / "module.py" module_file.write_text(model_str) - init_file = folder / '__init__.py' - init_file.write_text('from .module import *') + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") if len(blobified_modules) > 0: - warnings.warn("Was not able to save the following children modules as reprs -" - f"saved as pickled files instead: {blobified_modules}") + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py index fbe8400a437e..d2671787ea63 100644 --- a/colossalai/_analyzer/fx/node_util.py +++ b/colossalai/_analyzer/fx/node_util.py @@ -1,9 +1,9 @@ from dataclasses import dataclass, field -from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch -from torch.autograd.profiler_util import _format_memory, _format_time -from torch.fx import Graph, GraphModule, Node +from torch.autograd.profiler_util import _format_memory +from torch.fx import Node from colossalai._analyzer.envs import MeshConfig @@ -85,12 +85,12 @@ class MetaInfo: node: Node # directory - mod_dir: str = '' + mod_dir: str = "" # ctx[data_ptr] = Tensor # mark the storage for ctx.save_for_backward - global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared - curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node + global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared + curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node # should be updated after each graph manipulation # ============================== Update ==================================== @@ -100,7 +100,7 @@ class MetaInfo: inputs: Tuple[torch.Tensor] = () outputs: Tuple[torch.Tensor] = () - is_alias: Tuple[bool] = () # whether the output is an alias of input + is_alias: Tuple[bool] = () # whether the output is an alias of input # compute cost fwd_flop: Optional[int] = 0 @@ -112,29 +112,29 @@ class MetaInfo: # should keep the same whenever manipulated # ============================= Invariant ================================== - activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen + activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen to_offload: Optional[bool] = False - sharding_spec: str = 'RR' + sharding_spec: str = "RR" def __new__(cls, node: Node, **kwargs): orig_init = cls.__init__ # if initialized, return the existing one # should disable the __init__ function - if node.meta.get('info', None) is not None: + if node.meta.get("info", None) is not None: def _dummy(self, *args, **kwargs): - if getattr(self, '_is_init', False): + if getattr(self, "_is_init", False): self._is_init = True orig_init(self, *args, **kwargs) cls.__init__ = orig_init cls.__init__ = _dummy - return node.meta['info'] + return node.meta["info"] return super().__new__(cls) def __post_init__(self): - self.node.meta['info'] = self + self.node.meta["info"] = self @property def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH): @@ -188,24 +188,26 @@ def backward_size(self): return compute_size_in_bytes(self.inputs) def __repr__(self): - s = f'Node {self.node.name}' + s = f"Node {self.node.name}" if self.parameters: - s += f'\n\thas parameter of size {_format_memory(self.param_size)}' + s += f"\n\thas parameter of size {_format_memory(self.param_size)}" if self.buffers: - s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}' + s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}" if self.output_size: - s += f'\n\thas output activation of size {_format_memory(self.output_size)}' + s += f"\n\thas output activation of size {_format_memory(self.output_size)}" # if self.total_size: # s += f'\n\thas total activation of size {_format_memory(self.total_size)}' if self.temp_size: - s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}' + s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}" if self.backward_size: - s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}' - s += f'\n\tfwd_flop = {self.fwd_flop}'\ - f'\n\tbwd_flop = {self.bwd_flop}'\ - f'\n\tfwd_comm = {self.fwd_comm}'\ - f'\n\tbwd_comm = {self.bwd_comm}'\ - f'\n\tto_recompute = {self.to_recompute}'\ - f'\n\tto_offload = {self.to_offload}'\ - f'\n\tsharding_spec = {self.sharding_spec}' + s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}" + s += ( + f"\n\tfwd_flop = {self.fwd_flop}" + f"\n\tbwd_flop = {self.bwd_flop}" + f"\n\tfwd_comm = {self.fwd_comm}" + f"\n\tbwd_comm = {self.bwd_comm}" + f"\n\tto_recompute = {self.to_recompute}" + f"\n\tto_offload = {self.to_offload}" + f"\n\tsharding_spec = {self.sharding_spec}" + ) return s diff --git a/colossalai/_analyzer/fx/passes/graph_profile.py b/colossalai/_analyzer/fx/passes/graph_profile.py index c3e760b31e96..158ebce219cd 100644 --- a/colossalai/_analyzer/fx/passes/graph_profile.py +++ b/colossalai/_analyzer/fx/passes/graph_profile.py @@ -1,8 +1,8 @@ -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch import torch.fx -from torch.autograd.profiler_util import _format_memory, _format_time +from torch.autograd.profiler_util import _format_memory from torch.fx import GraphModule from torch.fx.node import Argument, Node, Target @@ -13,14 +13,14 @@ def _format_flops(flops: float) -> str: """Returns a formatted FLOP size string""" if flops > 1e12: - return f'{flops / 1e12:.2f} TFLOPs' + return f"{flops / 1e12:.2f} TFLOPs" elif flops > 1e9: - return f'{flops / 1e9:.2f} GFLOPs' + return f"{flops / 1e9:.2f} GFLOPs" elif flops > 1e6: - return f'{flops / 1e6:.2f} MFLOPs' + return f"{flops / 1e6:.2f} MFLOPs" elif flops > 1e3: - return f'{flops / 1e3:.2f} kFLOPs' - return f'{flops} FLOPs' + return f"{flops / 1e3:.2f} kFLOPs" + return f"{flops} FLOPs" def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]: @@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter): Fetch shape argument from ``ShapeProp`` without re-executing the ``GraphModule`` from scratch. """ + _profileable = [ - 'call_function', - 'call_module', - 'call_method', + "call_function", + "call_module", + "call_method", ] def __init__(self, module: GraphModule, garbage_collect_values: bool = True): @@ -77,14 +78,13 @@ def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_pr self.args_iter: Iterator[Any] = iter(args) for node in self.module.graph.nodes: - - self.run_node(node) # No need to store. + self.run_node(node) # No need to store. if self.garbage_collect_values: for to_delete in self.user_to_last_uses.get(node, []): del self.env[to_delete] - if node.op == 'output': + if node.op == "output": output_val = self.env[node] return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val @@ -133,9 +133,11 @@ def summary(self) -> str: try: from tabulate import tabulate except ImportError: - print("`summary` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) # Build up a list of summary information for each node node_summaries: List[List[Any]] = [] @@ -145,36 +147,38 @@ def summary(self) -> str: node: Node n_info = MetaInfo(node) last_n_info = last_n_info or n_info - node_summaries.append([ - node.op, - str(node), - _format_memory(n_info.accumulate_size), - _format_memory(n_info.accumulate_size - last_n_info.accumulate_size), - _format_memory(n_info.output_size), - _format_memory(n_info.temp_size), - _format_memory(n_info.param_size), - _format_memory(n_info.backward_size), - _format_flops(n_info.fwd_flop), - _format_flops(n_info.bwd_flop), - ]) + node_summaries.append( + [ + node.op, + str(node), + _format_memory(n_info.accumulate_size), + _format_memory(n_info.accumulate_size - last_n_info.accumulate_size), + _format_memory(n_info.output_size), + _format_memory(n_info.temp_size), + _format_memory(n_info.param_size), + _format_memory(n_info.backward_size), + _format_flops(n_info.fwd_flop), + _format_flops(n_info.bwd_flop), + ] + ) last_n_info = n_info # Use the ``tabulate`` library to create a well-formatted table # presenting our summary information headers: List[str] = [ - 'Op type', - 'Op', - 'Accumulate size', - 'Incremental size', - 'Output size', - 'Temp size', - 'Param size', - 'Backward size', - 'Fwd FLOPs', - 'Bwd FLOPs', + "Op type", + "Op", + "Accumulate size", + "Incremental size", + "Output size", + "Temp size", + "Param size", + "Backward size", + "Fwd FLOPs", + "Bwd FLOPs", ] - return tabulate(node_summaries, headers=headers, stralign='right') + return tabulate(node_summaries, headers=headers, stralign="right") class CommunicationProfiler(GraphProfiler): @@ -222,6 +226,7 @@ class with the ``@register_flop_count_impl`` decorator: >>> def my_fn_flop_count_impl(*args, **kwargs): >>> return 0, 0 """ + _custom_flop_count_impl = {} def run_node(self, n: torch.fx.Node) -> Any: @@ -246,11 +251,13 @@ def run_node(self, n: torch.fx.Node) -> Any: ( n_info.fwd_flop, n_info.bwd_flop, - ) = getattr(self, n.op)(n.target, args, kwargs) + ) = getattr( + self, n.op + )(n.target, args, kwargs) except Exception as e: raise RuntimeError( - f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. ' - f'Please refer to function\'s docstring to register the relevant profile_impl for this node!' + f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. " + f"Please refer to function's docstring to register the relevant profile_impl for this node!" ) from e # retain the autograd graph @@ -259,7 +266,7 @@ def run_node(self, n: torch.fx.Node) -> Any: return _denormalize_tuple(n_info.outputs) - def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node and return the profiling result. Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be @@ -283,7 +290,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di else: return flop_count(target, *args, **kwargs) - def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node and return the profiling result. @@ -301,7 +308,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict assert isinstance(target, str) return flop_count(getattr(torch.Tensor, target), *args, **kwargs) - def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_module`` node and return the profiling result. @@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule Returns: GraphModule: The same GraphModule with profiling information """ - for profiler_cls in (FlopProfiler, - # CommunicationProfiler, # TODO: add communication profiling - ): + for profiler_cls in ( + FlopProfiler, + # CommunicationProfiler, # TODO: add communication profiling + ): profiler = profiler_cls(module) profiler.propagate(*args, device=_current_device(module)) diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py index 23e83013e02f..8d44f1d4b59d 100644 --- a/colossalai/_analyzer/fx/passes/shape_prop.py +++ b/colossalai/_analyzer/fx/passes/shape_prop.py @@ -54,7 +54,7 @@ def _current_device(module): try: return next(module.parameters()).device except StopIteration: - return torch.device('cpu') + return torch.device("cpu") @compatibility(is_backward_compatible=False) @@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter): >>> # do something here >>> return torch.empty(output_shape, device=output_device) """ + _custom_dispatch_func = {} _mode = MetaTensorMode() @@ -115,15 +116,14 @@ def run_node(self, n: torch.fx.Node) -> Any: r = getattr(self, n.op)(n.target, args, kwargs) def unwrap_fn(elem): - def _convert_meta(t: torch.Tensor): - if t.device == 'meta': + if t.device == "meta": return t else: - return t.to('meta') + return t.to("meta") if isinstance(elem, MetaTensor): - if getattr(self, '_is_param', False): + if getattr(self, "_is_param", False): return torch.nn.Parameter(_convert_meta(elem._tensor)) return _convert_meta(elem._tensor) @@ -139,21 +139,24 @@ def _convert_meta(t: torch.Tensor): n_info = MetaInfo(n) n_info.outputs = _normalize_tuple(r) - if n.op == 'call_module': + if n.op == "call_module": submod = self.fetch_attr(n.target) n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()}) n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()}) else: - n_info.parameters.update({ - k.name: MetaTensor(v) - for k, v in zip(n.args, args) - if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter) - }) + n_info.parameters.update( + { + k.name: MetaTensor(v) + for k, v in zip(n.args, args) + if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter) + } + ) n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)}) - n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \ - tuple(v for v in kwargs.values() if is_pure_tensor(v)) + n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple( + v for v in kwargs.values() if is_pure_tensor(v) + ) # align with SPMD if isinstance(r, (tuple, list)): @@ -168,7 +171,7 @@ def _convert_meta(t: torch.Tensor): n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs)) return r - def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node and return the result. If the target of ``Node`` is registered with ``@register_shape_impl``, @@ -197,7 +200,7 @@ def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[st else: return res - def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node and return the result. @@ -218,7 +221,8 @@ def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, convert_to_parameter = False if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance( - args[0], torch.nn.parameter.Parameter): + args[0], torch.nn.parameter.Parameter + ): convert_to_parameter = True # Execute the method and return the result assert isinstance(target, str) diff --git a/colossalai/_analyzer/fx/symbolic_profile.py b/colossalai/_analyzer/fx/symbolic_profile.py index dd7f22c6c98a..5732a6665f78 100644 --- a/colossalai/_analyzer/fx/symbolic_profile.py +++ b/colossalai/_analyzer/fx/symbolic_profile.py @@ -1,5 +1,3 @@ -import torch -import torch.fx from torch.fx import GraphModule from .passes import ShapeProp, graph_profile_pass, shape_prop_pass @@ -7,7 +5,6 @@ def register_flop_count_impl(func): - def wrapper(impl): FlopProfiler._custom_flop_count_impl[func] = impl return impl @@ -16,7 +13,6 @@ def wrapper(impl): def register_shape_impl(func): - def wrapper(impl): ShapeProp._custom_dispatch_func[func] = impl return impl diff --git a/colossalai/_analyzer/fx/tracer/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py index 1e75b47ca5b0..b8b83282b42c 100644 --- a/colossalai/_analyzer/fx/tracer/bias_addition.py +++ b/colossalai/_analyzer/fx/tracer/bias_addition.py @@ -12,7 +12,7 @@ __all__ = [] -@register_tracer_impl(F.linear, name='_bias_addition_impl') +@register_tracer_impl(F.linear, name="_bias_addition_impl") def linear_impl(input, weight, bias=None): if bias is None: return F.linear(input, weight) @@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None): return F.linear(input, weight) + bias -@register_tracer_impl(F.conv1d, name='_bias_addition_impl') +@register_tracer_impl(F.conv1d, name="_bias_addition_impl") def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1): if bias is None: return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( - (-1, 1)) + (-1, 1) + ) -@register_tracer_impl(F.conv2d, name='_bias_addition_impl') +@register_tracer_impl(F.conv2d, name="_bias_addition_impl") def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1): if bias is None: return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( - (-1, 1, 1)) + (-1, 1, 1) + ) -@register_tracer_impl(F.conv3d, name='_bias_addition_impl') +@register_tracer_impl(F.conv3d, name="_bias_addition_impl") def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1): if bias is None: return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( - (-1, 1, 1, 1)) - - -@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl') -def conv_transpose1d_impl(input, - weight, - bias=None, - stride=_single(1), - padding=_single(0), - output_padding=_single(0), - groups=1, - dilation=_single(1)): + (-1, 1, 1, 1) + ) + + +@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl") +def conv_transpose1d_impl( + input, + weight, + bias=None, + stride=_single(1), + padding=_single(0), + output_padding=_single(0), + groups=1, + dilation=_single(1), +): if bias is None: - return F.conv_transpose1d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + return F.conv_transpose1d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) else: - return F.conv_transpose1d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + bias.reshape((-1, 1)) - - -@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl') -def conv_transpose2d_impl(input, - weight, - bias=None, - stride=_pair(1), - padding=_pair(0), - output_padding=_pair(0), - groups=1, - dilation=_pair(1)): + return F.conv_transpose1d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) + bias.reshape((-1, 1)) + + +@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl") +def conv_transpose2d_impl( + input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1) +): if bias is None: - return F.conv_transpose2d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + return F.conv_transpose2d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) else: - return F.conv_transpose2d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + bias.reshape((-1, 1, 1)) - - -@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl') -def conv_transpose3d_impl(input, - weight, - bias=None, - stride=_triple(1), - padding=_triple(0), - output_padding=_triple(0), - groups=1, - dilation=_triple(1)): + return F.conv_transpose2d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) + bias.reshape((-1, 1, 1)) + + +@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl") +def conv_transpose3d_impl( + input, + weight, + bias=None, + stride=_triple(1), + padding=_triple(0), + output_padding=_triple(0), + groups=1, + dilation=_triple(1), +): if bias is None: - return F.conv_transpose3d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + return F.conv_transpose3d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) else: - return F.conv_transpose3d(input, - weight, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation) + bias.reshape((-1, 1, 1, 1)) - - -@register_tracer_impl(torch.addmm, name='_bias_addition_impl') -@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl') + return F.conv_transpose3d( + input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation, + ) + bias.reshape((-1, 1, 1, 1)) + + +@register_tracer_impl(torch.addmm, name="_bias_addition_impl") +@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl") def addmm_impl(input, mat1, mat2, beta=1, alpha=1): if alpha != 1 and beta != 1: return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta @@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1): return F.linear(mat1, mat2.transpose(0, 1)) + input -@register_tracer_impl(torch.addbmm, name='_bias_addition_impl') -@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl') +@register_tracer_impl(torch.addbmm, name="_bias_addition_impl") +@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl") def addbmm_impl(input, batch1, batch2, beta=1, alpha=1): if alpha != 1 and beta != 1: return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta diff --git a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py index 112c7c9637d2..ff6b55be5117 100644 --- a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py +++ b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py @@ -4,6 +4,7 @@ try: import apex + register_leaf_module(apex.normalization.FusedLayerNorm) register_leaf_module(apex.normalization.FusedRMSNorm) register_leaf_module(apex.normalization.MixedFusedLayerNorm) diff --git a/colossalai/_analyzer/fx/tracer/proxy.py b/colossalai/_analyzer/fx/tracer/proxy.py index ce379efdcf0d..e3e210e7d190 100644 --- a/colossalai/_analyzer/fx/tracer/proxy.py +++ b/colossalai/_analyzer/fx/tracer/proxy.py @@ -1,10 +1,8 @@ import operator -from typing import Any, Callable, Dict, Optional, Set, Union +from typing import Any, Callable, Dict, Optional, Union import torch -import torch.nn as nn -from torch.fx import Graph, Node, Proxy, Tracer -from torch.fx.graph import _Namespace +from torch.fx import Node, Proxy from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor @@ -32,7 +30,7 @@ def meta_data(self, args): def __torch_function__(cls, orig_method, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs if orig_method in cls._func_dispatch: - impl = cls._func_dispatch.pop(orig_method) # avoid recursion + impl = cls._func_dispatch.pop(orig_method) # avoid recursion proxy = impl(*args, **kwargs) cls._func_dispatch[orig_method] = impl return proxy @@ -72,7 +70,7 @@ def __getattr__(self, k): return ColoAttribute(self, k, getattr(self._meta_data, k, None)) def __setitem__(self, key, value): - proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) + proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {}) proxy.meta_data = self._meta_data return proxy @@ -89,7 +87,6 @@ def __isinstancecheck__(self, type): class ColoAttribute(ColoProxy): - def __init__(self, root, attr: str, data=None): self.root = root self.attr = attr @@ -102,11 +99,11 @@ def node(self): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node return self._node def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) def __repr__(self): return f"ColoAttribute({self.node.name}, attr={self.attr})" diff --git a/colossalai/_analyzer/fx/tracer/symbolic_trace.py b/colossalai/_analyzer/fx/tracer/symbolic_trace.py index 2018863f6f5f..7884fd911c86 100644 --- a/colossalai/_analyzer/fx/tracer/symbolic_trace.py +++ b/colossalai/_analyzer/fx/tracer/symbolic_trace.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Union import torch from torch.fx import Tracer @@ -8,6 +8,7 @@ try: from ..codegen import ActivationCheckpointCodeGen + SUPPORT_ACTIVATION = True except: SUPPORT_ACTIVATION = False @@ -16,7 +17,7 @@ def _default_device(): - return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") def _current_device(module: torch.nn.Module): @@ -144,10 +145,9 @@ def forward(self, x): if meta_args: device, orig_device = _default_device(), _current_device(root) wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem - graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, - bias_addition_split=bias_addition_split).trace(root.to(device), - concrete_args=concrete_args, - meta_args=tree_map(wrap_fn, meta_args)) + graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace( + root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args) + ) if trace_act_ckpt and SUPPORT_ACTIVATION: graph.set_codegen(ActivationCheckpointCodeGen()) root.to(orig_device) diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py index 6958a00a6a72..17dce767269d 100644 --- a/colossalai/_analyzer/fx/tracer/tracer.py +++ b/colossalai/_analyzer/fx/tracer/tracer.py @@ -20,11 +20,10 @@ def _truncate_suffix(s: str): import re # FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name - return re.sub(r'_\d+$', '', s) + return re.sub(r"_\d+$", "", s) -def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'): - +def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"): def wrapper(impl): assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}" getattr(ColoTracer, name)[func] = impl @@ -34,7 +33,6 @@ def wrapper(impl): def register_leaf_module_impl(module: nn.Module): - def wrapper(impl): ColoTracer._custom_leaf_module_impl[module] = impl return impl @@ -76,7 +74,7 @@ def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = Fal self.ckpt_regions = [] self.ckpt_idx = 0 - self.mod_dir = '' + self.mod_dir = "" # whether the tracer should split the bias_add ops into two ops self.bias_addition_split = bias_addition_split @@ -87,35 +85,41 @@ def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None: return False # user can specify which modules are leaf modules and which are not - return (type(m) not in self._custom_non_leaf_module - and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name))) + return type(m) not in self._custom_non_leaf_module and ( + type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name) + ) - def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], - kwargs: Dict[str, Any]) -> Any: + def call_module( + self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: curr_dir = self.mod_dir - self.mod_dir = 'self.' + self.path_of_module(m) + self.mod_dir = "self." + self.path_of_module(m) rst = super().call_module(m, forward, args, kwargs) self.mod_dir = curr_dir return rst - def proxy(self, node: Node) -> 'ColoProxy': + def proxy(self, node: Node) -> "ColoProxy": return ColoProxy(node, self) - def create_proxy(self, - kind: str, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - proxy_factory_fn: Callable[[Node], 'Proxy'] = None): - + def create_proxy( + self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], "Proxy"] = None, + ): proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p - if kind == 'placeholder': - proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( - _truncate_suffix(target), None) - elif kind == 'get_attr': + if kind == "placeholder": + proxy.meta_data = ( + self.meta_args[target] + if target in self.meta_args + else self.concrete_args.get(_truncate_suffix(target), None) + ) + elif kind == "get_attr": self.disable_module_getattr = True try: attr_itr = self.root @@ -125,20 +129,21 @@ def create_proxy(self, proxy.meta_data = attr_itr finally: self.disable_module_getattr = False - elif kind == 'call_function': + elif kind == "call_function": proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_method': + elif kind == "call_method": self.disable_module_getattr = True try: - if target == '__call__': + if target == "__call__": proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) else: if target not in _TensorPropertyMethod: - proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), - **tree_map(unwrap_fn, kwargs)) + proxy._meta_data = getattr(unwrap_fn(args[0]), target)( + *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs) + ) finally: self.disable_module_getattr = False - elif kind == 'call_module': + elif kind == "call_module": mod = self.root.get_submodule(target) self.disable_module_getattr = True try: @@ -158,11 +163,12 @@ def create_node(self, *args, **kwargs) -> Node: n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions)) return node - def trace(self, - root: torch.nn.Module, - concrete_args: Optional[Dict[str, torch.Tensor]] = None, - meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: - + def trace( + self, + root: torch.nn.Module, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, + meta_args: Optional[Dict[str, torch.Tensor]] = None, + ) -> Graph: if meta_args is None: meta_args = {} @@ -177,9 +183,7 @@ def trace(self, non_concrete_arg_names = sig_names - concrete_arg_names # update concrete args with default values for k, v in sig.parameters.items(): - if k in sig_names - meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default def _check_arg_name_valid(names: Iterable[str]): @@ -194,9 +198,9 @@ def _check_arg_name_valid(names: Iterable[str]): self.meta_args = meta_args with self._torch_factory_override(), self._tracer_override(), torch.no_grad(): - self.mod_dir = 'self' + self.mod_dir = "self" self.graph = super().trace(root, concrete_args=concrete_args) - self.mod_dir = '' + self.mod_dir = "" self.graph.lint() for node in self.graph.nodes: @@ -266,17 +270,17 @@ def _torch_factory_override(self): # override the torch factory functions to create a proxy when the method # is called during ``symbolic_trace()``. def wrap_factory_method(target): - @functools.wraps(target) def wrapper(*args, **kwargs): is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( - isinstance(p, ColoProxy) for p in kwargs.values()) + isinstance(p, ColoProxy) for p in kwargs.values() + ) if is_proxy: # if the arg is a proxy, then need to record this function called on this proxy # e.g. torch.ones(size) where size is an input proxy self.disable_module_getattr = True try: - proxy = self.create_proxy('call_function', target, args, kwargs) + proxy = self.create_proxy("call_function", target, args, kwargs) finally: self.disable_module_getattr = False return proxy @@ -341,10 +345,13 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac if attr_val is p: if n not in parameter_proxy_cache: kwargs = {} - if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: - kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else - lambda node: ColoProxy(self, node, n, attr_val)) - val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ColoProxy(self, node, n, attr_val) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] parameter_proxy_cache[n] = val_proxy return parameter_proxy_cache[n] return None @@ -355,8 +362,9 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac return maybe_buffer_proxy if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), - parameter_proxy_cache) + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) if maybe_parameter_proxy is not None: return maybe_parameter_proxy diff --git a/colossalai/amp/naive_amp/grad_scaler/__init__.py b/colossalai/amp/naive_amp/grad_scaler/__init__.py index dc8499d877e1..34a20e8d67d6 100644 --- a/colossalai/amp/naive_amp/grad_scaler/__init__.py +++ b/colossalai/amp/naive_amp/grad_scaler/__init__.py @@ -2,4 +2,4 @@ from .constant_grad_scaler import ConstantGradScaler from .dynamic_grad_scaler import DynamicGradScaler -__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler'] +__all__ = ["BaseGradScaler", "ConstantGradScaler", "DynamicGradScaler"] diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index 0d84384a7f67..79661a44424f 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger -__all__ = ['BaseGradScaler'] +__all__ = ["BaseGradScaler"] class BaseGradScaler(ABC): @@ -30,24 +30,21 @@ def __init__(self, initial_scale: float, verbose: bool): @property def scale(self) -> Tensor: - """Returns the loss scale. - """ + """Returns the loss scale.""" return self._scale @property def inv_scale(self) -> Tensor: - """Returns the inverse of the loss scale. - """ + """Returns the inverse of the loss scale.""" return self._scale.double().reciprocal().float() def state_dict(self) -> Dict: - """Returns the states of the gradient scaler as a dict object. - """ + """Returns the states of the gradient scaler as a dict object.""" state_dict = dict() - state_dict['scale'] = self.scale + state_dict["scale"] = self.scale return state_dict def load_state_dict(self, state_dict: Dict) -> None: @@ -57,7 +54,7 @@ def load_state_dict(self, state_dict: Dict) -> None: state_dict (dict): the states of the gradient scaler """ - self._scale = state_dict['scale'] + self._scale = state_dict["scale"] @abstractmethod def update(self, overflow: bool) -> None: @@ -67,8 +64,6 @@ def update(self, overflow: bool) -> None: overflow (bool): whether overflow occurs """ - pass - def log(self, message, *args, **kwargs): """Log messages. diff --git a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py index a2f518c5dd28..2ad8b51ac22c 100644 --- a/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/constant_grad_scaler.py @@ -2,7 +2,7 @@ # -*- encoding: utf-8 -*- from .base_grad_scaler import BaseGradScaler -__all__ = ['ConstantGradScaler'] +__all__ = ["ConstantGradScaler"] class ConstantGradScaler(BaseGradScaler): @@ -23,4 +23,3 @@ def update(self, overflow: bool) -> None: Args: overflow (bool): whether overflow occurs """ - pass diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py index e899b9ca4c89..65133a4b3712 100644 --- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -7,7 +7,7 @@ from .base_grad_scaler import BaseGradScaler -__all__ = ['DynamicGradScaler'] +__all__ = ["DynamicGradScaler"] class DynamicGradScaler(BaseGradScaler): @@ -24,15 +24,17 @@ class DynamicGradScaler(BaseGradScaler): verbose (bool): whether to log messages, defaults to False """ - def __init__(self, - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - min_scale: Optional[float] = None, - max_scale: Optional[float] = None, - hysteresis: int = 2, - verbose: bool = False): + def __init__( + self, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + min_scale: Optional[float] = None, + max_scale: Optional[float] = None, + hysteresis: int = 2, + verbose: bool = False, + ): super().__init__(initial_scale, verbose) if min_scale: self._min_scale = torch.cuda.FloatTensor([min_scale]) @@ -53,18 +55,17 @@ def __init__(self, self._sanity_checks() def _sanity_checks(self) -> None: - """Check if the arguments are correct. - """ + """Check if the arguments are correct.""" if self._min_scale: - assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative' - assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale' + assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative" + assert self._min_scale <= self._scale, "The minimum gradient scale cannot be greater than the current scale" if self._max_scale: - assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative' - assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale' - assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1' - assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1' - assert self._hysteresis >= 0, 'The hysteresis cannot be negative' + assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative" + assert self._max_scale >= self._scale, "The maximum gradient scale cannot be smaller than the current scale" + assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1" + assert 0 < self._backoff_factor < 1, "The backoff factor must be between 0 and 1" + assert self._hysteresis >= 0, "The hysteresis cannot be negative" def update(self, overflow: bool) -> None: """Update the loss scale. @@ -88,19 +89,18 @@ def update(self, overflow: bool) -> None: self.log( f"No overflow for consecutive {self._growth_interval} steps, " f"the loss scale is adjusted to {self.scale.item()}", - ranks=[0]) + ranks=[0], + ) def _backoff_scale(self) -> None: - """Decrease the loss scale - """ + """Decrease the loss scale""" self._scale = self._scale * self._backoff_factor if self._min_scale: self._scale = torch.max(self._scale, self._min_scale) def _grow_scale(self) -> None: - """Increase the loss scale - """ + """Increase the loss scale""" self._scale = self._scale * self._growth_factor if self._max_scale: @@ -108,14 +108,14 @@ def _grow_scale(self) -> None: def state_dict(self): state_dict = dict() - state_dict['scale'] = self._scale - state_dict['growth_factor'] = self._growth_factor - state_dict['backoff_factor'] = self._backoff_factor - state_dict['hysteresis'] = self._hysteresis + state_dict["scale"] = self._scale + state_dict["growth_factor"] = self._growth_factor + state_dict["backoff_factor"] = self._backoff_factor + state_dict["hysteresis"] = self._hysteresis return state_dict def load_state_dict(self, state_dict): - self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) - self._growth_factor = state_dict['growth_factor'] - self._backoff_factor = state_dict['backoff_factor'] - self._hysteresis = state_dict['hysteresis'] + self._scale = state_dict["scale"].cuda(torch.cuda.current_device()) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._hysteresis = state_dict["hysteresis"] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py index b0348e1477bb..a31811e4a567 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py @@ -3,7 +3,7 @@ from .fp16 import FP16MixedPrecisionMixin __all__ = [ - 'MixedPrecisionMixin', - 'FP16MixedPrecisionMixin', - 'BF16MixedPrecisionMixin', + "MixedPrecisionMixin", + "FP16MixedPrecisionMixin", + "BF16MixedPrecisionMixin", ] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py index a52a9747ad1e..fc7e0b74179a 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -39,6 +39,7 @@ def zero_grad(self): return self.optim.zero_grad() ``` """ + dtype: torch.dtype @abstractmethod @@ -51,7 +52,6 @@ def pre_backward(self, loss: Tensor) -> Tensor: Returns: Tensor: Loss value (possibly scaled). """ - pass @abstractmethod def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: @@ -64,7 +64,6 @@ def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: Returns: Tensor: Gradient of the tensor (possibly scaled). """ - pass @abstractmethod def should_skip_step(self) -> bool: @@ -73,13 +72,10 @@ def should_skip_step(self) -> bool: Returns: bool: Whether to skip the step. """ - pass @abstractmethod def pre_zero_grad(self) -> None: - """Called before zero_grad. - """ - pass + """Called before zero_grad.""" @abstractmethod def get_grad_div_scale(self) -> float: @@ -88,4 +84,3 @@ def get_grad_div_scale(self) -> float: Returns: float: A divisor for gradient clipping or step. """ - pass diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py index 1ce8e42eb3ed..9ce272356797 100644 --- a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py @@ -19,22 +19,26 @@ class OptimState(Enum): class FP16MixedPrecisionMixin(MixedPrecisionMixin): dtype = torch.float16 - def __init__(self, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32) -> None: + def __init__( + self, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: super().__init__() - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.grad_scaler = DynamicGradScaler( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) self.optim_state = OptimState.UNSCALED self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device()) @@ -49,7 +53,6 @@ def check_local_overflow(self) -> bool: Returns: bool: Whether there is overflow in the local process. """ - pass def check_overflow(self) -> bool: # clear previous overflow record @@ -79,6 +82,6 @@ def pre_zero_grad(self) -> None: pass def get_grad_div_scale(self) -> float: - assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping' + assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping" self.optim_state = OptimState.UNSCALED return self.loss_scale diff --git a/colossalai/amp/naive_amp/mixed_precision_optimizer.py b/colossalai/amp/naive_amp/mixed_precision_optimizer.py index 626a00c96d04..6a192cc5cb83 100644 --- a/colossalai/amp/naive_amp/mixed_precision_optimizer.py +++ b/colossalai/amp/naive_amp/mixed_precision_optimizer.py @@ -11,18 +11,20 @@ class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): - - def __init__(self, - working_params: List[Parameter], - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32) -> None: - super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, - max_scale) + def __init__( + self, + working_params: List[Parameter], + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + ) self.params = working_params def check_local_overflow(self) -> bool: @@ -33,38 +35,41 @@ def check_local_overflow(self) -> bool: class MixedPrecisionOptimizer(OptimizerWrapper): - - def __init__(self, - optim: Optimizer, - precision: str = 'fp16', - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0.0): + def __init__( + self, + optim: Optimizer, + precision: str = "fp16", + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + ): super().__init__(optim) - if precision == 'fp16': + if precision == "fp16": working_params = [] for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: working_params.append(p) - self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params, - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) - elif precision == 'bf16': + self.mixed_precision = NaiveFP16MixedPrecisionMixin( + working_params, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) + elif precision == "bf16": self.mixed_precision = BF16MixedPrecisionMixin() else: - raise ValueError(f'Unsupported precision: {precision}') + raise ValueError(f"Unsupported precision: {precision}") if max_norm > 0.0: - raise NotImplementedError('max_norm is not supported yet.') + raise NotImplementedError("max_norm is not supported yet.") self.max_norm = max_norm self.working_to_master_map: Dict[Parameter, Tensor] = {} self.master_to_working_map: Dict[Tensor, Parameter] = {} @@ -72,7 +77,7 @@ def __init__(self, # create master weights for group in self.optim.param_groups: master_params = [] - for p in group['params']: + for p in group["params"]: if p.requires_grad: master_p = p if p.dtype != torch.float: @@ -80,7 +85,7 @@ def __init__(self, self.working_to_master_map[p] = master_p self.master_to_working_map[master_p] = p master_params.append(master_p) - group['params'] = master_params + group["params"] = master_params def backward(self, loss: Tensor, *args, **kwargs): loss = self.mixed_precision.pre_backward(loss) @@ -101,24 +106,24 @@ def _unscale_and_clip_grads(self, total_norm: float) -> None: if self.mixed_precision is not None: div_scale = self.mixed_precision.get_grad_div_scale() - if self.max_norm > 0.: + if self.max_norm > 0.0: # norm is in fact norm*scale clip = ((total_norm / div_scale) + 1e-6) / self.max_norm if clip > 1: div_scale = clip * div_scale for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue - p.grad.data.mul_(1. / div_scale) + p.grad.data.mul_(1.0 / div_scale) def _compute_grad_norm(self) -> float: - if self.max_norm <= 0.: - return 0. - grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None] + if self.max_norm <= 0.0: + return 0.0 + grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None] if len(grads) == 0: - return 0. + return 0.0 device = grads[0].device # TODO(ver217): support tp total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2) @@ -130,7 +135,7 @@ def step(self, *args, **kwargs): return # prepare grads for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: working_param = self.master_to_working_map[p] if p is working_param: continue @@ -142,7 +147,7 @@ def step(self, *args, **kwargs): self.optim.step(*args, **kwargs) # update working params for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: working_param = self.master_to_working_map[p] if p is working_param: continue diff --git a/colossalai/auto_parallel/checkpoint/build_c_ext.py b/colossalai/auto_parallel/checkpoint/build_c_ext.py index af4349865a7b..7de56f80525a 100644 --- a/colossalai/auto_parallel/checkpoint/build_c_ext.py +++ b/colossalai/auto_parallel/checkpoint/build_c_ext.py @@ -3,14 +3,16 @@ from setuptools import Extension, setup this_dir = os.path.dirname(os.path.abspath(__file__)) -ext_modules = [Extension( - 'rotorc', - sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')], -)] +ext_modules = [ + Extension( + "rotorc", + sources=[os.path.join(this_dir, "ckpt_solver_rotor.c")], + ) +] setup( - name='rotor c extension', - version='0.1', - description='rotor c extension for faster dp computing', + name="rotor c extension", + version="0.1", + description="rotor c extension for faster dp computing", ext_modules=ext_modules, ) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py index b388d00ac553..8aaa690b333c 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -12,13 +12,13 @@ ) from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen -__all___ = ['CheckpointSolverBase'] +__all___ = ["CheckpointSolverBase"] def _copy_output(src: Graph, dst: Graph): """Copy the output node from src to dst""" for n_src, n_dst in zip(src.nodes, dst.nodes): - if n_src.op == 'output': + if n_src.op == "output": n_dst.meta = n_src.meta @@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module): class CheckpointSolverBase(ABC): - def __init__( self, graph: Graph, @@ -81,13 +80,10 @@ def __init__( @abstractmethod def solve(self): - """Solve the checkpointing problem and return the solution. - """ - pass + """Solve the checkpointing problem and return the solution.""" def get_node_list(self): - """Get the node list. - """ + """Get the node list.""" return [[node] for node in self.graph.nodes] def _linearize_graph(self) -> List[List[Node]]: @@ -140,8 +136,7 @@ def _is_sink() -> bool: """ def _is_inplace(n: Node): - """Get the inplace argument from ``torch.fx.Node`` - """ + """Get the inplace argument from ``torch.fx.Node``""" inplace = False if n.op == "call_function": inplace = n.kwargs.get("inplace", False) @@ -150,19 +145,22 @@ def _is_inplace(n: Node): return inplace def _is_shape_consistency(n: Node): - """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``) - """ + """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)""" return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] - return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( - map(_is_shape_consistency, n.users)) + return ( + not sum([v for _, v in deps.items()]) + and not any(map(_is_inplace, n.users)) + and not any(map(_is_shape_consistency, n.users)) + ) # make sure that item in cnode is valid if self.cnode: for name in self.cnode: try: - assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ - f"Common node {name} is not an input of the model." + assert ( + next(node for node in self.graph.nodes if node.name == name).op == "placeholder" + ), f"Common node {name} is not an input of the model." except StopIteration: raise ValueError(f"Common node name {name} not in graph.") @@ -187,8 +185,9 @@ def _is_shape_consistency(n: Node): region = [] # propagate common node attr if possible - if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode - ]) or _is_cop(n.target): + if len(n.all_input_nodes) == len( + [node for node in n.all_input_nodes if node.name in self.cnode] + ) or _is_cop(n.target): self.cnode.append(n.name) else: deps[n] = len([user for user in n.users if user.op != "output"]) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py index 19b2ef5987c9..ab16cc04b730 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py @@ -8,11 +8,10 @@ from .ckpt_solver_base import CheckpointSolverBase -__all__ = ['CheckpointSolverChen'] +__all__ = ["CheckpointSolverChen"] class CheckpointSolverChen(CheckpointSolverBase): - def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6): """ This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. @@ -40,14 +39,14 @@ def solve(self) -> Graph: Returns: graph (Graph): The optimized graph, should be a copy of the original graph. """ - checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr'] + checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"] ckpt = self.grid_search() for i, seg in enumerate(ckpt): for idx in range(*seg): nodes = self.node_list[idx] for n in nodes: if n.op in checkpointable_op: - n.meta['activation_checkpoint'] = i + n.meta["activation_checkpoint"] = i return deepcopy(self.graph) def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]: diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py index 21c3bf0da758..d10c41ae2b96 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Dict, List, Tuple +from typing import Any, List, Tuple from torch import Tensor from torch.fx import Graph, Node @@ -18,17 +18,18 @@ from .ckpt_solver_base import CheckpointSolverBase from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence -__all__ = ['CheckpointSolverRotor'] +__all__ = ["CheckpointSolverRotor"] class CheckpointSolverRotor(CheckpointSolverBase): - - def __init__(self, - graph: Graph, - free_memory: float = -1, - cnode: List[str] = None, - memory_slots: int = 500, - optim_multiplier: float = 1.0): + def __init__( + self, + graph: Graph, + free_memory: float = -1, + cnode: List[str] = None, + memory_slots: int = 500, + optim_multiplier: float = 1.0, + ): """This is the simple implementation of dynamic programming algorithm rotor in https://hal.inria.fr/hal-02352969. Some code are adapted from https://gitlab.inria.fr/hiepacs/rotor. @@ -85,13 +86,14 @@ def solve(self, force_python: bool = False, verbose: bool = False) -> Graph: # backtrack try: - self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, - self.back_ptr) + self.sequence = self._backtrack( + chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.back_ptr + ) self._annotate_from_sequence(self.sequence, self.node_list) except ValueError as e: # using logger to annonce that the solver is failed logger = get_dist_logger() - logger.warning(f'Checkpoint solver failed: {e}') + logger.warning(f"Checkpoint solver failed: {e}") raise ValueError if verbose: @@ -100,14 +102,19 @@ def solve(self, force_python: bool = False, verbose: bool = False) -> Graph: return deepcopy(self.graph) def print_chain(self): - print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0]) + print("[input]", self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0]) for idx in range(len(self.node_list) - 1): - print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx], - self.chain.btmp[idx]) - print(f'Chain = {self.chain}') + print( + self.node_list[idx], + self.chain.x[idx + 1], + self.chain.xbar[idx + 1], + self.chain.ftmp[idx], + self.chain.btmp[idx], + ) + print(f"Chain = {self.chain}") def print_sequence(self): - print(f'Sequence = {self.sequence}') + print(f"Sequence = {self.sequence}") @classmethod def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain: @@ -138,14 +145,14 @@ def _extract_node_info(cls, node: List[Node]) -> Tuple[int, ...]: btime = 0 fwd_mem_peak = 0 for n in node: - assert isinstance(n, Node), f'{n} is not a Node' + assert isinstance(n, Node), f"{n} is not a Node" if n.target == runtime_apply or n.target == runtime_comm_spec_apply: # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta - xbar += n.meta['fwd_mem_out'] - fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp']) + xbar += n.meta["fwd_mem_out"] + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"]) else: xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) - fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) + fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"] + cls._extract_unused_output(n)) # minimum flop count is required ftime += max(calculate_fwd_time(n), 1.0) @@ -162,14 +169,14 @@ def _extract_input(graph: Graph) -> Tuple[Tensor, ...]: """Extract input tensors from a Graph""" input_tensors = [] for node in graph.nodes: - if node.op == 'placeholder': - input_tensors.append(node.meta['fwd_out']) + if node.op == "placeholder": + input_tensors.append(node.meta["fwd_out"]) return input_tensors @staticmethod def _extract_unused_output(node: Node) -> int: """Extract unused output from `torch.fx.Node`""" - return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node) + return activation_size(node.meta["fwd_out"]) - calculate_fwd_out(node) @staticmethod def _extract_btmp(node: List[Node]) -> int: @@ -180,8 +187,8 @@ def _extract_deps_size(): for k, v in deps.items(): k: Node if v > 0: - deps_size += k.meta['bwd_mem_out'] - if v == float('-inf'): + deps_size += k.meta["bwd_mem_out"] + if v == float("-inf"): deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k) return deps_size @@ -190,12 +197,12 @@ def _extract_deps_size(): deps = {} for n in reversed(node): deps[n] = len(n.all_input_nodes) - btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp']) + btmp = max(btmp, _extract_deps_size() + n.meta["bwd_mem_tmp"]) for child in n.users: if child in deps: deps[child] -= 1 if deps[child] <= 0: - deps[child] = float('-inf') # free + deps[child] = float("-inf") # free return btmp @staticmethod @@ -244,10 +251,11 @@ def _compute_table(chain: Chain, mmax: int) -> Tuple: if m < mmin: cost_table[m][i][idx] = float("inf") else: - leaf_checkpoints = [(j, - sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1]) - for j in range(i + 1, idx + 1) - if m >= x[j]] + leaf_checkpoints = [ + (j, sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1]) + for j in range(i + 1, idx + 1) + if m >= x[j] + ] if leaf_checkpoints: best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) else: @@ -274,13 +282,16 @@ def _compute_table_c(chain: Chain, mmax: int) -> Tuple: import os import subprocess import sys + logger = get_dist_logger() logger.info("rotorc hasn't been built! Building library...", ranks=[0]) this_dir = os.path.dirname(os.path.abspath(__file__)) result = subprocess.Popen( [ - f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext", - f"--build-lib={this_dir}" + f"{sys.executable}", + f"{os.path.join(this_dir, 'build_c_ext.py')}", + "build_ext", + f"--build-lib={this_dir}", ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -294,8 +305,9 @@ def _compute_table_c(chain: Chain, mmax: int) -> Tuple: return compute_table(chain, mmax) @staticmethod - def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], - back_ptr: List[Any]) -> "Sequence": + def _backtrack( + chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], back_ptr: List[Any] + ) -> "Sequence": """Backtrack the cost table and retrieve the optimal checkpointing strategy. Args: @@ -328,8 +340,9 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A if back_ptr[budget][lhs][rhs][0]: sequence += [ ForwardEnable(lhs), - CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, - back_ptr), + CheckpointSolverRotor._backtrack( + chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, back_ptr + ), Backward(lhs), ] else: @@ -337,8 +350,9 @@ def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[A sequence += [ForwardCheck(lhs)] sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)] sequence += [ - CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, - back_ptr), + CheckpointSolverRotor._backtrack( + chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, back_ptr + ), CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr), ] return sequence @@ -353,8 +367,8 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): """ op_list = sequence.list_operations() loss_op = next(op for op in op_list if isinstance(op, Loss)) - fwd_list = op_list[:op_list.index(loss_op)] - bwd_list = op_list[op_list.index(loss_op) + 1:] + fwd_list = op_list[: op_list.index(loss_op)] + bwd_list = op_list[op_list.index(loss_op) + 1 :] ckpt_idx = 0 in_ckpt = False ckpt_region = [] @@ -369,7 +383,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): in_ckpt = False for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'] = [ckpt_idx] + n.meta["activation_checkpoint"] = [ckpt_idx] ckpt_idx += 1 ckpt_region = [] @@ -377,7 +391,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): elif isinstance(op, ForwardCheck): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'] = [ckpt_idx] + n.meta["activation_checkpoint"] = [ckpt_idx] ckpt_idx += 1 ckpt_region = [idx] @@ -397,7 +411,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): elif isinstance(op, ForwardEnable): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'].append(ckpt_idx) + n.meta["activation_checkpoint"].append(ckpt_idx) ckpt_idx += 1 ckpt_region = [] @@ -405,7 +419,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): elif isinstance(op, ForwardCheck): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'].append(ckpt_idx) + n.meta["activation_checkpoint"].append(ckpt_idx) ckpt_idx += 1 ckpt_region = [op.index] @@ -413,7 +427,7 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): elif isinstance(op, Backward): for node_idx in ckpt_region: for n in node_list[node_idx]: - n.meta['activation_checkpoint'].append(ckpt_idx) + n.meta["activation_checkpoint"].append(ckpt_idx) in_recompute = False @@ -431,9 +445,11 @@ def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]): for node in node_list: op_list += node ckpt_regions = _find_nested_ckpt_regions(op_list) - for (start_idx, end_idx) in ckpt_regions: + for start_idx, end_idx in ckpt_regions: nested_length = max( - len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1)) + len(op_list[idx].meta["activation_checkpoint"]) for idx in range(start_idx, end_idx + 1) + ) for idx in range(start_idx, end_idx + 1): - op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length - - len(op_list[idx].meta['activation_checkpoint'])) + op_list[idx].meta["activation_checkpoint"] += [None] * ( + nested_length - len(op_list[idx].meta["activation_checkpoint"]) + ) diff --git a/colossalai/auto_parallel/checkpoint/operation.py b/colossalai/auto_parallel/checkpoint/operation.py index ab0c6c5ad38d..5f8077916433 100644 --- a/colossalai/auto_parallel/checkpoint/operation.py +++ b/colossalai/auto_parallel/checkpoint/operation.py @@ -1,20 +1,21 @@ import math from abc import ABC -from typing import Any, Iterable, List +from typing import List from torch.utils._pytree import tree_map class Chain: - - def __init__(self, - ftime: List[float], - btime: List[float], - x: List[int], - xbar: List[int], - ftmp: List[int], - btmp: List[int], - check_consistency: bool = True): + def __init__( + self, + ftime: List[float], + btime: List[float], + x: List[int], + xbar: List[int], + ftmp: List[int], + btmp: List[int], + check_consistency: bool = True, + ): """The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint. See paper https://hal.inria.fr/hal-02352969 for details. @@ -37,9 +38,14 @@ def __init__(self, raise AttributeError("In Chain, input lists do not have consistent lengths") def check_lengths(self): - return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1) - and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1) - and (len(self.xbar) == len(self) + 1)) + return ( + (len(self.ftime) == len(self)) + and (len(self.btime) == len(self) + 1) + and (len(self.x) == len(self) + 1) + and (len(self.ftmp) == len(self)) + and (len(self.btmp) == len(self) + 1) + and (len(self.xbar) == len(self) + 1) + ) def __repr__(self): chain_list = [] @@ -100,7 +106,6 @@ class ForwardCheck(Forward): class Forwards(Operation): - def __init__(self, start, end): self.index = (start, end) @@ -109,9 +114,9 @@ def __repr__(self): def cost(self, chain: Chain): if chain is not None: - return sum(chain.ftime[self.index[0]:self.index[1] + 1]) + return sum(chain.ftime[self.index[0] : self.index[1] + 1]) else: - return (self.index[1] - self.index[0] + 1) + return self.index[1] - self.index[0] + 1 def isForward(op): @@ -132,7 +137,6 @@ def cost(self, chain: Chain): class Loss(Operation): - def __init__(self): pass @@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess): class Sequence(list): - def __init__(self): super().__init__() diff --git a/colossalai/auto_parallel/meta_profiler/constants.py b/colossalai/auto_parallel/meta_profiler/constants.py index 35b8c13ee8ff..2f638fa919e4 100644 --- a/colossalai/auto_parallel/meta_profiler/constants.py +++ b/colossalai/auto_parallel/meta_profiler/constants.py @@ -3,8 +3,6 @@ import torch import torch.nn as nn -from ..tensor_shard.constants import * - # list of inplace module INPLACE_MODULE = [nn.ReLU] diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index 0f2e9e44f91c..4234481ae2ca 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0 def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: input_tensor = next( filter( - lambda x: - (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim', - args)).data + lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) + and x.name != "softmax_dim", + args, + ) + ).data output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data - is_inplace = 1 if kwargs.get('inplace', False) else 0 + is_inplace = 1 if kwargs.get("inplace", False) else 0 flop_counter = elementwise_flop_counter(1, 0) # calculate compute cost fwd_compute_cost = flop_counter([input_tensor], [output_tensor]) bwd_compute_cost = flop_counter([output_tensor], [input_tensor]) - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) # calculate memory cost # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: if in_place is True, we will not create a new tensor in forward - fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace), - parameter=0, - temp=0, - buffer=activation_size(input_tensor) * buffer_mem_scale) + fwd_memory_cost = MemoryCost( + activation=activation_size(input_tensor) * (2 - is_inplace), + parameter=0, + temp=0, + buffer=activation_size(input_tensor) * buffer_mem_scale, + ) # temp_mem_scale is for situation like softmax backward # the buffer will be removed during backward phase @@ -54,20 +58,23 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale, parameter=0, temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale, - buffer=0) + buffer=0, + ) # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, - temp=fwd_memory_cost.temp + bwd_memory_cost.temp, - buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + temp=fwd_memory_cost.temp + bwd_memory_cost.temp, + buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out fwd_in = [] - fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_buffer = [torch.zeros_like(output_tensor, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index e451748512b9..0b7b51a71955 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -6,10 +6,10 @@ from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION +from ..constants import BCAST_FUNC_OP from ..registry import meta_register -__all__ = ['binary_elementwise_meta_info'] +__all__ = ["binary_elementwise_meta_info"] @meta_register.register(BCAST_FUNC_OP) @@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train # store fwd_in, fwd_buffer, fwd_out fwd_in = [] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_op_data.data, device='meta')] + fwd_out = [torch.zeros_like(output_op_data.data, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py index 4336bf68363c..2f630995cdbc 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -1,22 +1,14 @@ -from typing import Callable, Dict, List, Tuple, Union +from typing import List, Tuple import torch from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from ..registry import meta_register -__all__ = ['convnd_meta_info'] +__all__ = ["convnd_meta_info"] @meta_register.register(torch.nn.Conv1d) @@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,)) - bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \ - flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor)) + bwd_compute_cost = ( + flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) + if has_bias + else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor)) + ) compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # calculate memory cost # TODO: use profiler to check conv temp memory # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) - if has_bias else compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) - - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) - if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) - if has_bias else compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) + + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias + else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py index d5d80f5b3700..7c9add810fd8 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py @@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor]) - bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor], - [weight_tensor]) + bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]( + [output_tensor, weight_tensor], [weight_tensor] + ) compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) @@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=0, - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0 + ) bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0) total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py index 94dd9143e0ae..d731f9cb4436 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py @@ -1,23 +1,15 @@ from functools import reduce -from typing import Callable, Dict, List, Tuple, Union +from typing import List, Tuple import torch from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from ..registry import meta_register -__all__ = ['linear_meta_info', 'matmul_meta_info'] +__all__ = ["linear_meta_info", "matmul_meta_info"] @meta_register.register(torch.nn.functional.linear) @@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default]( - [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) - bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ - flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \ - flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,)) - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,) + ) + bwd_compute_cost = ( + flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + + flop_mapping[torch.ops.aten.mm.default]( + [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,) + ) + + flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,)) + ) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) # calculate memory cost # NOTE: Linear don't have buffer and temp in forward and backward phase # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=0, + ) # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0 - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=0) + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=0, + ) # total cost is to sum the forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) @@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate compute cost fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( - [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) - bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ - flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,) + ) + bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [output_tensor, weight_tensor], (input_tensor,) + ) + flop_mapping[torch.ops.aten.mm.default]( + [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,) + ) - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) # calculate memory cost # NOTE: Linear don't have buffer and temp in forward and backward phase # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), - parameter=compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0 - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]), - parameter=compute_size_in_bytes(weight_tensor), - temp=0, - buffer=0) + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0, + ) # total cost is to sum the forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out @@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # batched gemv case 1: batched matrix-vector multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors) + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors + ) # combine the dimensions of output bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( - [output_tensors[0].reshape(-1), input_tensors[1]], - output_tensors) + \ - flop_mapping[torch.ops.aten.matmul.default]( - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)], - output_tensors) + [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors + ) + flop_mapping[torch.ops.aten.matmul.default]( + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)], + output_tensors, + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) @@ -239,86 +253,104 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # gemv case 2: vector-matrix multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors) - bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \ - flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors) + bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( + [output_tensors[0], input_tensors[0]], output_tensors + ) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), - parameter=0, - temp=compute_size_in_bytes(input_tensors[1]), - buffer=0) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]), + buffer=0, + ) elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3: # batched gemv case 2: vector-batched matrix multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]], - [output_tensors[0].reshape(-1)]) + [output_tensors[0].reshape(-1)], + ) # combine the dimensions of output bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( - [output_tensors[0].reshape(-1), input_tensors[0]], - output_tensors - ) + \ - flop_mapping[torch.ops.aten.matmul.default]( - [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)], - output_tensors - ) + [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors + ) + flop_mapping[torch.ops.aten.matmul.default]( + [ + input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), + output_tensors[0].reshape(-1), + ], + output_tensors, + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]])) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), - parameter=0, - temp=compute_size_in_bytes(input_tensors[1]), - buffer=0) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors[0]), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]), + buffer=0, + ) elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2: # gemm & batched gemm case 1: batched matrix-matrix multiplication fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], - [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])]) + [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])], + ) bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])], - [input_tensors[1]] - ) + \ - flop_mapping[torch.ops.aten.mm.default]( - [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)], - [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])] - ) + [ + input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), + output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), + ], + [input_tensors[1]], + ) + flop_mapping[torch.ops.aten.mm.default]( + [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)], + [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])], + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3: # batched gemm case 2: matrix-batched matrix multiplication - fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([ - input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose( - 0, 1) - ], [output_tensors[0].transpose(-2, -1)]) + fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( + [ + input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), + input_tensors[0].transpose(0, 1), + ], + [output_tensors[0].transpose(-2, -1)], + ) bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( - [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])], - [input_tensors[0]] - ) + \ - flop_mapping[torch.ops.aten.mm.default]( - [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]], - [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])] - ) - - fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) + - compute_size_in_bytes(input_tensors[1]), - temp=compute_size_in_bytes(output_tensors)) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), - parameter=0, - temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors)) + [ + output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), + input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), + ], + [input_tensors[0]], + ) + flop_mapping[torch.ops.aten.mm.default]( + [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]], + [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])], + ) + + fwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]), + temp=compute_size_in_bytes(output_tensors), + ) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors[0]), + parameter=0, + temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors), + ) elif all(len(tensor.shape) >= 3 for tensor in input_tensors): # Batched matrix-batched matrix multiplication # Fetch shape of the two inputs and see if the batch dimensions are the same _is_batch_dims_same = True if len(input_tensors[0].shape) == len(input_tensors[1].shape): - for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]): + for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]): if shape_0 != shape_1: _is_batch_dims_same = False break @@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # Case 1: batch dimensions are the same # Forward compute cost: C = A * B - fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([ - input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape( - -1, input_dim_10, input_dim_11) - ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]) + fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( + [ + input_tensors[0].reshape(-1, input_dim_00, input_dim_01), + input_tensors[1].reshape(-1, input_dim_10, input_dim_11), + ], + [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], + ) # Backward compute cost: dB = A^T * dC, dA = dC * B^T bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( - [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], - [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)] - ) + \ - flop_mapping[torch.ops.aten.bmm.default]( - [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)], - [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)] - ) + [ + input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), + output_tensors[0].reshape(-1, output_dim_0, output_dim_1), + ], + [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)], + ) + flop_mapping[torch.ops.aten.bmm.default]( + [ + output_tensors[0].reshape(-1, output_dim_0, output_dim_1), + input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10), + ], + [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)], + ) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors)) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors)) @@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L else: # Case 2: batch dimensions are different batch_dims = output_tensors[0].shape[:-2] - extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims), - input_dim_00, - input_dim_01, - device="meta") - extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims), - input_dim_10, - input_dim_11, - device="meta") + extended_input_0 = torch.rand( + reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta" + ) + extended_input_1 = torch.rand( + reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta" + ) # Forward compute cost: C = A * B fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( - [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]) + [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)] + ) # Backward compute cost: dB = A^T * dC, dA = dC * B^T bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( - [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], - [extended_input_1] - ) + \ - flop_mapping[torch.ops.aten.bmm.default]( - [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)], - [extended_input_0] - ) + [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], + [extended_input_1], + ) + flop_mapping[torch.ops.aten.bmm.default]( + [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)], + [extended_input_0], + ) fwd_mem_cost = MemoryCost( - activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) - - compute_size_in_bytes([extended_input_0, extended_input_1]), - temp=compute_size_in_bytes([extended_input_0, extended_input_1])) + activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]) + ) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensors) + - compute_size_in_bytes([extended_input_0, extended_input_1]), + temp=compute_size_in_bytes([extended_input_0, extended_input_1]), + ) # compute cost compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # memory cost - total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, - parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, - temp=fwd_mem_cost.temp + bwd_mem_cost.temp, - buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + total_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py index 12874810b13e..b1bb1d872c35 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py @@ -3,7 +3,7 @@ import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from ..registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py index b872fdc8bdcd..99aaa752d0a1 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py @@ -1,22 +1,14 @@ -from typing import Callable, Dict, List, Tuple, Union +from typing import List, Tuple import torch from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from ..registry import meta_register -__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info'] +__all__ = ["batchnormnd_meta_info", "layernorm_meta_info"] @meta_register.register(torch.nn.BatchNorm1d) @@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt # saved inv std and some other args indicating the status of the module # the bwd outputs are input grad, weight grad and bias grad bwd_in_args = [ - output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch + output_tensor, + output_tensor, + weight_tensor, + mean_tensor, + var_tensor, + mean_tensor, + var_tensor, + 1e-5, + num_batch, ] bwd_out_args = [input_tensor, weight_tensor, bias_tensor] @@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt # calculate memory cost # the fwd activation cost is output plus saved mean and saved inv std # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( - [input_tensor, output_tensor, mean_tensor, var_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=compute_size_in_bytes([mean_tensor, var_tensor])) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=compute_size_in_bytes([mean_tensor, var_tensor]), + ) # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean # and saved inv std during backward phase - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=compute_size_in_bytes([mean_tensor, var_tensor]), - buffer=compute_size_in_bytes([mean_tensor, var_tensor])) + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([mean_tensor, var_tensor]), + buffer=compute_size_in_bytes([mean_tensor, var_tensor]), + ) # total cost is the sum of forward and backward cost - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] - fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] + fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out @@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data weight_tensor = next(filter(lambda x: x.name == "weight", args)).data bias_tensor = next(filter(lambda x: x.name == "bias", args)).data - running_mean = torch.rand(input_tensor.shape[0], 1, device='meta') - running_var = torch.rand(input_tensor.shape[0], 1, device='meta') + running_mean = torch.rand(input_tensor.shape[0], 1, device="meta") + running_var = torch.rand(input_tensor.shape[0], 1, device="meta") # construct args fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor] @@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # memory cost # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( - [input_tensor, output_tensor, weight_tensor, bias_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=0, - buffer=compute_size_in_bytes([running_mean, running_var])) - - bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), - parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), - temp=compute_size_in_bytes([running_mean, running_var]), - buffer=compute_size_in_bytes([running_mean, running_var])) - - total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, - parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, - temp=fwd_memory_cost.temp + bwd_memory_cost.temp, - buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) + fwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=0, + buffer=compute_size_in_bytes([running_mean, running_var]), + ) + + bwd_memory_cost = MemoryCost( + activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([running_mean, running_var]), + buffer=compute_size_in_bytes([running_mean, running_var]), + ) + + total_cost = MemoryCost( + activation=fwd_memory_cost.activation + bwd_memory_cost.activation, + parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, + temp=fwd_memory_cost.temp + bwd_memory_cost.temp, + buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] - fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] + fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py index d785dfcca9ba..21aa524bed08 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py @@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, # store fwd_in, fwd_buffer, fwd_out fwd_in = [] fwd_buffer = [] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out @@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix])) # temp memory for backward is the index matrix to be discarded - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix), - temp=compute_size_in_bytes(index_matrix)) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix), + temp=compute_size_in_bytes(index_matrix), + ) # total cost total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp) @@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) # store fwd_in, fwd_buffer, fwd_out - fwd_in = [torch.zeros_like(input_tensor, device='meta')] - fwd_buffer = [torch.zeros_like(index_matrix, device='meta')] - fwd_out = [torch.zeros_like(output_tensor, device='meta')] + fwd_in = [torch.zeros_like(input_tensor, device="meta")] + fwd_buffer = [torch.zeros_like(index_matrix, device="meta")] + fwd_out = [torch.zeros_like(output_tensor, device="meta")] return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py index 97fe3c6196f5..9a2df1bd7c87 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py @@ -2,7 +2,6 @@ import torch -from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem @@ -37,15 +36,19 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor, - parameter=0, - temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor, - buffer=0) + bwd_mem_cost = MemoryCost( + activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor, + parameter=0, + temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor, + buffer=0, + ) - total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, - parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, - temp=fwd_mem_cost.temp + bwd_mem_cost.temp, - buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + total_mem_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) @@ -66,14 +69,24 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor # register torch.Tensor related metainfo # (0, 0) -meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, - torch.arange])(tensor_related_metainfo(0, 0)) +meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])( + tensor_related_metainfo(0, 0) +) # (1, 0) -meta_register.register([ - torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute, - torch.Tensor.split, torch.split, torch.Tensor.view -])(tensor_related_metainfo(1, 0)) +meta_register.register( + [ + torch.Tensor.flatten, + torch.flatten, + torch.Tensor.transpose, + torch.transpose, + torch.Tensor.permute, + torch.permute, + torch.Tensor.split, + torch.split, + torch.Tensor.view, + ] +)(tensor_related_metainfo(1, 0)) # (1, 1) meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1)) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py index 5cba1b5b6e2b..107851b80d7c 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py @@ -4,7 +4,7 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from ..registry import meta_register @@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li # gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase # NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor])) - bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]), - parameter=0, - temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) - - activation_size([x_tensor, y_tensor]), - buffer=0) - - total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, - parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, - temp=fwd_mem_cost.temp + bwd_mem_cost.temp, - buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) + bwd_mem_cost = MemoryCost( + activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]), + parameter=0, + temp=activation_size([output_tensor]) * 3 + + activation_size([condition_tensor]) + - activation_size([x_tensor, y_tensor]), + buffer=0, + ) + + total_mem_cost = MemoryCost( + activation=fwd_mem_cost.activation + bwd_mem_cost.activation, + parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, + temp=fwd_mem_cost.temp + bwd_mem_cost.temp, + buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) diff --git a/colossalai/auto_parallel/meta_profiler/registry.py b/colossalai/auto_parallel/meta_profiler/registry.py index 46350c4dd406..c29086f7f9d1 100644 --- a/colossalai/auto_parallel/meta_profiler/registry.py +++ b/colossalai/auto_parallel/meta_profiler/registry.py @@ -1,14 +1,12 @@ -__all__ = ['Registry'] +__all__ = ["Registry"] class Registry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): if isinstance(source, (list, tuple)): # support register a list of items for this func @@ -21,7 +19,7 @@ def wrapper(func): return wrapper def get(self, source): - assert source in self.store, f'{source} not found in the {self.name} registry' + assert source in self.store, f"{source} not found in the {self.name} registry" target = self.store[source] return target @@ -29,4 +27,4 @@ def has(self, source): return source in self.store -meta_register = Registry('meta') +meta_register = Registry("meta") diff --git a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py index 0eee908b48b7..109b8a220ac7 100644 --- a/colossalai/auto_parallel/meta_profiler/shard_metainfo.py +++ b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py @@ -2,20 +2,13 @@ import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem from colossalai.tensor.sharding_spec import ShardingSpec from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION from .registry import meta_register -__all__ = ['ShardMetaInfo'] +__all__ = ["ShardMetaInfo"] class ShardMetaInfo: @@ -76,10 +69,12 @@ def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: S """ if isinstance(sharding_spec, ShardingSpec): - op_data = OperationData(name=operation_data.name, - data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), - type=operation_data.type, - logical_shape=operation_data.logical_shape) + op_data = OperationData( + name=operation_data.name, + data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"), + type=operation_data.type, + logical_shape=operation_data.logical_shape, + ) elif isinstance(sharding_spec, (list, tuple)): data = operation_data.data assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}." @@ -97,8 +92,9 @@ def compute_shard_metainfo(self): """ Compute meta info based on sharding strategy and the given target function. """ - assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \ - f"Meta info for {self._target} is not registered." + assert meta_register.has(self._target.__class__) or meta_register.has( + self._target + ), f"Meta info for {self._target} is not registered." if meta_register.has(self._target.__class__): # module meta_func = meta_register.get(self._target.__class__) @@ -117,11 +113,11 @@ def compute_shard_metainfo(self): # construct kwargs if self.target in INPLACE_MODULE: - kwargs = {'inplace': self.target.inplace} + kwargs = {"inplace": self.target.inplace} elif self.target in INPLACE_OPS: - kwargs = {'inplace': True} + kwargs = {"inplace": True} else: - kwargs = {'inplace': False} + kwargs = {"inplace": False} # compute metainfo with meta_func self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs) diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py index 353133bd6f2d..601bf2926d99 100644 --- a/colossalai/auto_parallel/offload/amp_optimizer.py +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -37,19 +37,20 @@ class AMPOptimizer(OptimizerWrapper): norm_type (float, optional): norm_type used for `clip_grad_norm`. """ - def __init__(self, - optimizer: Optimizer, - module: BaseOffloadModule, - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - clipping_norm: float = 0.0, - norm_type: float = 2.0): - + def __init__( + self, + optimizer: Optimizer, + module: BaseOffloadModule, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + clipping_norm: float = 0.0, + norm_type: float = 2.0, + ): super().__init__(optimizer) self.module = module @@ -69,19 +70,21 @@ def __init__(self, self.__init__optimizer() # Grad scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.grad_scaler = DynamicGradScaler( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) self._logger = get_dist_logger() def _set_grad_ptr(self): for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: region = self.param_to_region[fake_param] begin, end = self.param_to_range[fake_param] @@ -92,7 +95,7 @@ def _set_grad_ptr(self): def _update_fp16_params(self): none_tensor = torch.empty([0]) for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: assert fake_param.grad is None fake_param.data = none_tensor self.param_to_region[fake_param].cpu_grad = None @@ -130,10 +133,10 @@ def step(self, *args, **kwargs): found_inf = self._check_overflow() if found_inf: - self.optim_state = OptimState.UNSCALED # no need to unscale grad - self.grad_scaler.update(found_inf) # update gradient scaler - self._logger.info(f'Found overflow. Skip step') - self.zero_grad() # reset all gradients + self.optim_state = OptimState.UNSCALED # no need to unscale grad + self.grad_scaler.update(found_inf) # update gradient scaler + self._logger.info(f"Found overflow. Skip step") + self.zero_grad() # reset all gradients self._update_fp16_params() return @@ -156,11 +159,10 @@ def backward(self, loss: torch.Tensor): self.module.backward(loss) def __init__optimizer(self): - for group in self.optim.param_groups: fake_params_list = list() - for param in group['params']: + for param in group["params"]: region = self.region_manager.get_region(param) fake_param = torch.nn.Parameter(torch.empty([0])) self.param_to_range[fake_param] = region.param_to_range[param] @@ -171,7 +173,7 @@ def __init__optimizer(self): if param in self.optim.state: self.optim.state[fake_param] = self.optim.state.pop(param) - group['params'] = fake_params_list + group["params"] = fake_params_list # Leverage state_dict() and load_state_dict() to # recast preexisting per-param state tensors diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py index 5b9f74b132f3..f5e8e31f5e97 100644 --- a/colossalai/auto_parallel/offload/base_offload_module.py +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -22,7 +22,6 @@ class BaseOffloadModule: """ def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True): - self.model = model self.region_manager = region_manager self.grad_hook_list = [] @@ -91,17 +90,16 @@ def _cast_buffers(self): def parameters(self, recurse: bool = True): return self.model.parameters(recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True): + def named_parameters(self, prefix: str = "", recurse: bool = True): return self.model.named_parameters(prefix, recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True): + def named_buffers(self, prefix: str = "", recurse: bool = True): return self.model.named_buffers(prefix, recurse) def named_children(self): return self.model.named_children() - def named_modules(self, - memo: Optional[Set[torch.nn.Module]] = None, - prefix: str = '', - remove_duplicate: bool = True): + def named_modules( + self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): return self.model.named_modules(memo, prefix, remove_duplicate) diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py index d56166dea982..74501c184518 100644 --- a/colossalai/auto_parallel/offload/mem_optimize.py +++ b/colossalai/auto_parallel/offload/mem_optimize.py @@ -14,11 +14,9 @@ from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem -def memory_optimize(model: torch.nn.Module, - inps: Dict[str, torch.Tensor], - memory_budget: float = -1.0, - solver_name: str = 'asyn'): - +def memory_optimize( + model: torch.nn.Module, inps: Dict[str, torch.Tensor], memory_budget: float = -1.0, solver_name: str = "asyn" +): model = model.cpu().half() tracer = ColoTracer() assert is_compatible_with_meta() @@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module, f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}" ) - if solver_name == 'syn': + if solver_name == "syn": gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list) - elif solver_name == 'asyn': + elif solver_name == "asyn": gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list) else: raise TypeError(f"Unknown solver name {solver_name}!") gm.recompile() - optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn') + optimized_model = BaseOffloadModule(gm, region_manager, solver_name == "syn") return optimized_model diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py index 819ffbd96eb1..ea92c714ce31 100644 --- a/colossalai/auto_parallel/offload/region.py +++ b/colossalai/auto_parallel/offload/region.py @@ -55,13 +55,13 @@ def init_param_data(self, pre_alloc_tensor: torch.Tensor = None): Map the parameters in the region to a contiguous memory space. """ - self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda') + self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device="cuda") offset = 0 for param in self.fp16_params: param.data = param.data.cuda() p_num = param.data.numel() - self.fp16_data[offset:offset + p_num].copy_(param.data.flatten()) - param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape) + self.fp16_data[offset : offset + p_num].copy_(param.data.flatten()) + param.data = self.fp16_data[offset : offset + p_num].view(param.data.shape) self.param_to_range[param] = (offset, offset + p_num) offset += p_num @@ -83,7 +83,7 @@ def move_param_to_cuda(self): self.temp_fp32_data.record_stream(torch.cuda.current_stream()) if not self.in_mem_pool_flag: alloc_storage(self.fp16_data) - self.fp16_data[:self.param_num].copy_(self.temp_fp32_data) + self.fp16_data[: self.param_num].copy_(self.temp_fp32_data) self.fp16_data.record_stream(torch.cuda.current_stream()) self.__update_params_ptr() @@ -94,7 +94,7 @@ def move_grad_to_cpu(self): """ self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True) - self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True) + self.cpu_grad.copy_(self.fp16_data[: self.param_num], non_blocking=True) self.fp16_data.record_stream(torch.cuda.current_stream()) if not self.in_mem_pool_flag: self.free_cuda_data() diff --git a/colossalai/auto_parallel/offload/region_manager.py b/colossalai/auto_parallel/offload/region_manager.py index 30bfaf00d493..146dd267967d 100644 --- a/colossalai/auto_parallel/offload/region_manager.py +++ b/colossalai/auto_parallel/offload/region_manager.py @@ -1,10 +1,11 @@ -from typing import List, Any, Dict, Tuple +from typing import Any, Dict, List, Tuple + import torch from torch.fx import Graph, Node +from .region import Region from .solver import SolverFactory from .training_simulator import TrainingSimulator -from .region import Region from .util import NodeInfo @@ -19,14 +20,9 @@ class RegionManager: cnode (List[str], optional): Common node List, should be the subset of input. """ - def __init__(self, - graph: Graph, - solver_name: str = 'asyn', - memory_budget: float = -1.0, - cnode: List[str] = None): - + def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None): self.graph = graph - assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + assert graph.owning_module is not None, "The given graph is not associated with a owning_module" self.root_module = self.graph.owning_module self.nodes = list(graph.nodes) self.cnode = cnode @@ -39,7 +35,7 @@ def __init__(self, self.memory_budget = memory_budget self.solver_name = solver_name - self.require_pool: bool = solver_name == 'asyn' + self.require_pool: bool = solver_name == "asyn" self.reg_to_block: Dict[int, int] = dict() @@ -61,22 +57,19 @@ def _build_regions(self): self._post_process(solver.best_ts) def _pre_process(self): - init_region_list = self._linearize_graph() if len(self.shared_region_pairs) > 1: - raise NotImplementedError( - 'The current version only considers at most one pair of parameter sharing.') + raise NotImplementedError("The current version only considers at most one pair of parameter sharing.") elif len(self.shared_region_pairs) == 1: shared_regs = self.shared_region_pairs[0] - assert shared_regs[0].shared_rid == shared_regs[1].r_id \ - and shared_regs[1].shared_rid == shared_regs[0].r_id + assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id fst_id = shared_regs[0].r_id lst_id = shared_regs[1].r_id - regs_left_out = init_region_list[:fst_id + 1] + regs_left_out = init_region_list[: fst_id + 1] regs_right_out = init_region_list[lst_id:] - hold_regs = init_region_list[fst_id + 1:lst_id] + hold_regs = init_region_list[fst_id + 1 : lst_id] else: regs_left_out = [] regs_right_out = [] @@ -122,12 +115,9 @@ def _early_region_placement(self, ts: TrainingSimulator): it may not find a suitable region placement strategy for the given execution flow. """ - reg_flow = torch.cat( - [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) - mem_block_num = torch.max( - torch.sum(reg_flow[:, self.rid_in_pool], dim=1)) - coexist_matrix = torch.logical_or( - ts.fwd_reg_flow, ts.bwd_reg_flow) + reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) + mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1)) + coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow) block_to_regs = {} for block_idx in range(mem_block_num): @@ -135,8 +125,7 @@ def _early_region_placement(self, ts: TrainingSimulator): for reg in self.region_list: if reg.r_id in self.rid_in_pool: cur_reg_appears = coexist_matrix[:, reg.r_id] - cur_reg_coexists = torch.sum( - coexist_matrix[cur_reg_appears], dim=0).bool() + cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool() for block_idx in range(mem_block_num): if not any(cur_reg_coexists[block_to_regs[block_idx]]): block_to_regs[block_idx].append(reg.r_id) @@ -145,9 +134,12 @@ def _early_region_placement(self, ts: TrainingSimulator): if reg.r_id not in self.reg_to_block: raise NotImplementedError( - f'can not find a block from the memory pool to store parameters of the region') - self.memory_pool = torch.chunk(torch.zeros(int( - mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num)) + f"can not find a block from the memory pool to store parameters of the region" + ) + self.memory_pool = torch.chunk( + torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"), + chunks=int(mem_block_num), + ) def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]: """ @@ -178,10 +170,9 @@ def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]: return region_list - def _search_block_size(self, - region_list: List[Region], - search_interval_byte: int = 1024, - search_range_byte: int = 128 * 1024 ** 2) -> int: + def _search_block_size( + self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2 + ) -> int: """ Search for a suitable memory block size. @@ -208,11 +199,10 @@ def _get_wasted_mem(size_list: List[int], blk_size: int): acc_wasted += blk_size - left return acc_wasted - param_size_list = [ - region.param_size for region in region_list if region.r_id == region.shared_rid] + param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid] start_size = max(param_size_list) - min_mem_waste = float('+inf') + min_mem_waste = float("+inf") best_block_size = start_size for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): @@ -229,7 +219,7 @@ def _init_region_data(self): Initialize region data, which maps the parameters in the region to a contiguous memory space. """ - self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32) + self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32) for region in self.region_list: pre_alloc_tensor = None @@ -244,8 +234,7 @@ def _init_region_data(self): region.fp16_data = shared_region.fp16_data region.fp32_data = shared_region.fp32_data region.param_to_range = shared_region.param_to_range - region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach( - ) + region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach() torch.cuda.empty_cache() @@ -259,13 +248,14 @@ def _process_shared_region(self): former_reg, latter_reg = self.shared_region_pairs[0] assert latter_reg.param_num >= former_reg.param_num embedding_node = former_reg.nodes[-1] - assert embedding_node.op == 'call_module' and isinstance( - self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding) + assert embedding_node.op == "call_module" and isinstance( + self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding + ) if latter_reg.param_num > former_reg.param_num: for idx, n in enumerate(latter_reg.nodes): - if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target), - torch.nn.Linear)) or \ - (n.op == 'call_function' and n.target is torch.nn.functional.linear): + if ( + n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear) + ) or (n.op == "call_function" and n.target is torch.nn.functional.linear): cut_node_idx = idx + 1 break assert len(latter_reg.fp16_params) == 2 @@ -273,7 +263,7 @@ def _process_shared_region(self): for p in new_reg.fp16_params: self.param_region_map[p] = new_reg self.region_list.insert(new_reg.r_id, new_reg) - for reg in self.region_list[new_reg.r_id + 1:]: + for reg in self.region_list[new_reg.r_id + 1 :]: reg.r_id += 1 latter_reg.shared_rid = former_reg.r_id former_reg.shared_rid = latter_reg.r_id @@ -344,8 +334,8 @@ def _maybe_param_comp_start() -> bool: target = n.target submod = self.root_module.get_submodule(target) if ( - len(list(submod.named_parameters(recurse=False))) != 0 - or len(list(submod.named_buffers(recurse=False))) != 0 + len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0 ): label = True @@ -362,14 +352,12 @@ def _is_param_comp_end() -> bool: """ def _is_inplace(n: Node): - """Get the inplace argument from ``torch.fx.Node`` - """ + """Get the inplace argument from ``torch.fx.Node``""" inplace = False if n.op == "call_function": inplace = n.kwargs.get("inplace", False) elif n.op == "call_module": - inplace = getattr(n.graph.owning_module.get_submodule( - n.target), "inplace", False) + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) return inplace label = False @@ -378,28 +366,30 @@ def _is_inplace(n: Node): target = n.target submod = self.root_module.get_submodule(target) if ( - len(list(submod.named_parameters(recurse=False))) != 0 - or len(list(submod.named_buffers(recurse=False))) != 0 + len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0 ): label = True elif n.op == "call_function": label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any( - map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)) + map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes) + ) return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users)) def _exception_node_handling(): # TODO meta info prop bug - if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2: - n.meta['fwd_out'] = [] + if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2: + n.meta["fwd_out"] = [] # make sure that item in cnode is valid if self.cnode: for name in self.cnode: try: - assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ - f"Common node {name} is not an input of the model." + assert ( + next(node for node in self.graph.nodes if node.name == name).op == "placeholder" + ), f"Common node {name} is not an input of the model." except StopIteration: raise ValueError(f"Common node name {name} not in graph.") else: @@ -428,8 +418,8 @@ def _exception_node_handling(): ns = [] border_n_idx = region.nodes.index(act_n) if border_n_idx < len(region.nodes): - ns = region.nodes[border_n_idx + 1:] - region.nodes = region.nodes[:border_n_idx + 1] + ns = region.nodes[border_n_idx + 1 :] + region.nodes = region.nodes[: border_n_idx + 1] region_list.append(region) region_id += 1 region = Region(r_id=region_id) @@ -448,19 +438,21 @@ def _exception_node_handling(): region = Region(r_id=region_id) # propagate common node attr if possible - if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode - ]) or _is_cop(n.target): + if len(n.all_input_nodes) == len( + [node for node in n.all_input_nodes if node.name in self.cnode] + ) or _is_cop(n.target): self.cnode.append(n.name) else: - deps[n] = len( - [user for user in n.users if user.op != "output"]) + deps[n] = len([user for user in n.users if user.op != "output"]) # propagate param node attr if possible - if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops - ]) or n.op == "get_attr": + if ( + len(n.all_input_nodes) + == len([node for node in n.all_input_nodes if node.name in self.only_param_ops]) + or n.op == "get_attr" + ): self.only_param_ops.append(n.name) - param_op_deps[n] = len( - [user for user in n.users if user.op != "output"]) + param_op_deps[n] = len([user for user in n.users if user.op != "output"]) # record last activation node if _is_act(n._meta_data): @@ -472,19 +464,16 @@ def _exception_node_handling(): return region_list def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region): - cur_n.node_info = NodeInfo(node_id) - if cur_n.op == 'call_module': + if cur_n.op == "call_module": target = cur_n.target submod = self.root_module.get_submodule(target) for p in list(submod.parameters(recurse=False)): - if p in self.param_region_map: cur_reg.shared_rid = self.param_region_map[p].r_id self.param_region_map[p].shared_rid = cur_reg.r_id - self.shared_region_pairs.append( - (self.param_region_map[p], cur_reg)) + self.shared_region_pairs.append((self.param_region_map[p], cur_reg)) else: self.param_region_map[p] = cur_reg @@ -499,12 +488,10 @@ def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region): attr_itr = getattr(attr_itr, atom) if isinstance(attr_itr, torch.nn.Parameter): - if attr_itr in self.param_region_map: cur_reg.shared_rid = self.param_region_map[attr_itr].r_id self.param_region_map[attr_itr].shared_rid = cur_reg.r_id - self.shared_region_pairs.append( - (self.param_region_map[attr_itr], cur_reg)) + self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg)) else: self.param_region_map[attr_itr] = cur_reg diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py index 764ac608826b..cc790dfb0891 100644 --- a/colossalai/auto_parallel/offload/runtime.py +++ b/colossalai/auto_parallel/offload/runtime.py @@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function): @staticmethod def forward(ctx, input_, fwd_info, bwd_info): ctx.bwd_info = bwd_info - d2h_rid = fwd_info.get('d2h_rid', None) + d2h_rid = fwd_info.get("d2h_rid", None) if d2h_rid is not None: free_region = GlobalRuntimeInfo().region_list[d2h_rid] assert isinstance(free_region, Region) free_region.free_cuda_data() - h2d_rid = fwd_info.get('h2d_rid', None) + h2d_rid = fwd_info.get("h2d_rid", None) if h2d_rid is not None: h2d_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(h2d_region, Region) @@ -38,8 +38,7 @@ def forward(ctx, input_, fwd_info, bwd_info): @staticmethod def backward(ctx, grad_output): - - h2d_rid = ctx.bwd_info.get('h2d_rid', None) + h2d_rid = ctx.bwd_info.get("h2d_rid", None) if h2d_rid is not None: pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) @@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function): def forward(ctx, input_, fwd_info, bwd_info): ctx.bwd_info = bwd_info - sync_rid = fwd_info.get('sync_rid', None) + sync_rid = fwd_info.get("sync_rid", None) if sync_rid is not None: prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None) if prefetch_event: prefetch_event.wait() - h2d_rid = fwd_info.get('h2d_rid', None) + h2d_rid = fwd_info.get("h2d_rid", None) if h2d_rid is not None: pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) @@ -87,8 +86,7 @@ def forward(ctx, input_, fwd_info, bwd_info): @staticmethod def backward(ctx, grad_output): - - sync_rid = ctx.bwd_info.get('sync_rid', None) + sync_rid = ctx.bwd_info.get("sync_rid", None) if sync_rid is not None: wait_region = GlobalRuntimeInfo().region_list[sync_rid] assert isinstance(wait_region, Region) @@ -98,7 +96,7 @@ def backward(ctx, grad_output): else: wait_region.move_param_to_cuda() - h2d_rid = ctx.bwd_info.get('h2d_rid', None) + h2d_rid = ctx.bwd_info.get("h2d_rid", None) if h2d_rid is not None: pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) @@ -114,7 +112,7 @@ def backward(ctx, grad_output): def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): - ''' + """ Convert Upload and Offload operation into runtime action. Argument: @@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): that need to be uploaded, or freed during forward pass. bwd_info(dict): information dict, which contains region indices that need to be uploaded during backward pass. - ''' + """ with torch._C.DisableTorchFunction(): ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) return ret def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): - ''' + """ Convert Prefetch and Offload operation into runtime action. Argument: @@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): that need to be prefetched, waited, or freed during forward pass. bwd_info(dict): information dict, which contains region indices that need to be prefetched or waited during backward pass. - ''' + """ with torch._C.DisableTorchFunction(): ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) return ret @@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R # forward upload fwd_info = {} if requires_upload_p_in_fwd(region_list[region.shared_rid]): - fwd_info['h2d_rid'] = region.r_id + fwd_info["h2d_rid"] = region.r_id # forward offload if r_idx > 0 and region_list[r_idx - 1].need_offload: - fwd_info['d2h_rid'] = r_idx - 1 + fwd_info["d2h_rid"] = r_idx - 1 bwd_info = {} # backward upload if r_idx > 0 and region_list[r_idx - 1].need_offload: - bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id + bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id if fwd_info or bwd_info: with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', - convert_fwd_upload_bwd_offload_to_action, - args=(last_inp_node, fwd_info, bwd_info)) + new_node = mod_graph.create_node( + "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info) + ) replace_node_users(last_inp_node, new_node) last_inp_node = region.nodes[-1] @@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ first_region_with_p = [region for region in region_list if region.param_size][0] fwd_info = {"h2d_rid": first_region_with_p.r_id} with mod_graph.inserting_after(last_inp_node): - upload_apply_node = mod_graph.create_node('call_function', - convert_fwd_upload_bwd_offload_to_action, - args=(last_inp_node, fwd_info, {})) + upload_apply_node = mod_graph.create_node( + "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {}) + ) replace_node_users(last_inp_node, upload_apply_node) last_inp_node = upload_apply_node @@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ # forward prefetch fwd_info = {} if region.param_size: - fwd_info['sync_rid'] = region.r_id + fwd_info["sync_rid"] = region.r_id fwd_prefetch_region = region.fwd_prefetch_region if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]): - fwd_info['h2d_rid'] = fwd_prefetch_region.r_id + fwd_info["h2d_rid"] = fwd_prefetch_region.r_id # forward offload if r_idx > 0 and region_list[r_idx - 1].need_offload: - fwd_info['d2h_rid'] = r_idx - 1 + fwd_info["d2h_rid"] = r_idx - 1 bwd_info = {} # backward prefetch if r_idx > 0 and region_list[r_idx - 1].need_offload: - bwd_info['sync_rid'] = r_idx - 1 + bwd_info["sync_rid"] = r_idx - 1 if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region: - bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id + bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id if fwd_info or bwd_info: with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', - convert_fwd_prefetch_bwd_offload_to_action, - args=(last_inp_node, fwd_info, bwd_info)) + new_node = mod_graph.create_node( + "call_function", + convert_fwd_prefetch_bwd_offload_to_action, + args=(last_inp_node, fwd_info, bwd_info), + ) replace_node_users(last_inp_node, new_node) last_inp_node = region.nodes[-1] if region.bwd_prefetch_region: - bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id} + bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id} with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', - convert_fwd_prefetch_bwd_offload_to_action, - args=(last_inp_node, {}, bwd_info)) + new_node = mod_graph.create_node( + "call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info) + ) replace_node_users(last_inp_node, new_node) # gm.graph.print_tabular() return gm diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py index 161f7ff86898..a6b4904f2617 100644 --- a/colossalai/auto_parallel/offload/solver.py +++ b/colossalai/auto_parallel/offload/solver.py @@ -1,6 +1,6 @@ import time -from typing import List, Dict, Type from abc import ABC, abstractmethod +from typing import Dict, List, Type NOT_NVML = False try: @@ -10,10 +10,11 @@ import torch from torch.fx.node import Node + from colossalai.utils.cuda import get_current_device -from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator from .region import Region +from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator from .util import NodeInfo, NvDevicePower @@ -49,19 +50,14 @@ class Solver(ABC): It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time. """ - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0, - error_factor: float = 0.95) -> None: - + def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None: self.region_list = region_list self.error_factor: float = error_factor if memory_budget > 0: self.memory_budget = memory_budget * self.error_factor else: - self.memory_budget = torch.cuda.get_device_properties( - get_current_device()).total_memory * self.error_factor + self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() self.comp_power: float = self._extract_computing_power() @@ -94,7 +90,7 @@ def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: floa if extra_cost == 0: # means data transfer overhead can be completely overlapped - return (float('inf'), total_mem_saving, peak_mem_saving) + return (float("inf"), total_mem_saving, peak_mem_saving) return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving) def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool: @@ -122,9 +118,7 @@ def _update_state(self, best_ts: TrainingSimulator): self.best_ts = best_ts self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem) - def _update_node_mem_info(self, - fwd_mem_info: Dict[Node, float], - bwd_mem_info: Dict[Node, float]): + def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]): """ Update the runtime memory information of the node. @@ -134,12 +128,10 @@ def _update_node_mem_info(self, """ for node, mem in fwd_mem_info.items(): - assert hasattr(node, 'node_info') and isinstance( - node.node_info, NodeInfo) + assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo) node.node_info.runtime_fwd_mem = mem for node, mem in bwd_mem_info.items(): - assert hasattr(node, 'node_info') and isinstance( - node.node_info, NodeInfo) + assert hasattr(node, "node_info") and isinstance(node.node_info, NodeInfo) node.node_info.runtime_bwd_mem = mem def _extract_computing_power(self): @@ -159,12 +151,12 @@ def _extract_computing_power(self): return NvDevicePower.RTX3080_FP16 * units elif device_name.__contains__("RTX 3090"): return NvDevicePower.RTX3090_FP16 * units - elif device_name.__contains__('V100'): + elif device_name.__contains__("V100"): return NvDevicePower.V100_FP16 * units elif device_name.__contains__("A100"): return NvDevicePower.A100_FP16 * units else: - raise TypeError(f'Unknown NVIDIA GPU device name {device_name}') + raise TypeError(f"Unknown NVIDIA GPU device name {device_name}") def _profile_bandwidth(self): """ @@ -172,9 +164,9 @@ def _profile_bandwidth(self): using data volumes ranging from 1KB to 1GB. """ - print('profiling bandwidth ......') + print("profiling bandwidth ......") link_to_bandwidth = {} - links = ['h2d', 'd2h'] + links = ["h2d", "d2h"] for link in links: t_size = 1024 @@ -182,24 +174,22 @@ def _profile_bandwidth(self): # from 1KB to 1GB for i in range(21): - if link == 'h2d': - src_tensor = torch.ones( - int(t_size), dtype=torch.int8, pin_memory=True) - dst_tensor = torch.ones( - (int(t_size)), dtype=torch.int8, device='cuda') - elif link == 'd2h': - src_tensor = torch.ones( - int(t_size), dtype=torch.int8, device='cuda') - dst_tensor = torch.ones( - (int(t_size)), dtype=torch.int8, pin_memory=True) + if link == "h2d": + src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True) + dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device="cuda") + elif link == "d2h": + src_tensor = torch.ones(int(t_size), dtype=torch.int8, device="cuda") + dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True) def func(): dst_tensor.copy_(src_tensor) size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3) - print(f'size: {t_size / 1024 ** 2:.3f} MB, ' - f'{src_tensor.device.type}-to-{dst_tensor.device.type} ' - f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s') + print( + f"size: {t_size / 1024 ** 2:.3f} MB, " + f"{src_tensor.device.type}-to-{dst_tensor.device.type} " + f"bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s" + ) t_size *= 2 @@ -208,10 +198,7 @@ def func(): class SynGreedySolver(Solver): - - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0) -> None: + def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None: super().__init__(region_list, memory_budget) self.best_ts: SynTrainingSimulator = None @@ -258,7 +245,8 @@ def _call_solver(self): else: raise NotImplementedError( f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " - f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") + f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!" + ) def _call_solver_l2l(self): """ @@ -270,7 +258,6 @@ def _call_solver_l2l(self): region.is_syn = True def _try_to_offload(self, offload_region: Region): - # record previous information orig_need_offload = offload_region.need_offload assert not orig_need_offload @@ -297,23 +284,17 @@ def _eval_one_choice(self, offload_region: Region): ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() - extra_comm_cost = 2.0 * \ - ts._get_communication_overhead('h2d', offload_region.param_size) + extra_comm_cost = 2.0 * ts._get_communication_overhead("h2d", offload_region.param_size) # the shared region needs to be moved twice if offload_region.r_id < offload_region.shared_rid: extra_comm_cost *= 2.0 - profit = self._compute_offload_profit( - ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) return ts, profit class AsynGreedySolver(Solver): - - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0, - search_window_size: int = 3): + def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3): super().__init__(region_list, memory_budget) self.search_window_size = search_window_size @@ -331,7 +312,7 @@ def _init_state(self): ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() self._update_state(ts) - print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB") + print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB") def _call_solver(self): """ @@ -358,18 +339,17 @@ def _call_solver(self): best_pref_ts = None # search when to prefetch the region offloaded - for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]: + for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]: if host_region.bwd_prefetch_region is not None: continue - temp_ts, profit = self._try_to_offload( - host_region, region) + temp_ts, profit = self._try_to_offload(host_region, region) if self._compare_profit(profit, max_prefetch_profit): region_to_region_map[region.r_id] = host_region max_prefetch_profit = profit best_pref_ts = temp_ts - if profit[0] == float('inf'): + if profit[0] == float("inf"): break if self._compare_profit(max_prefetch_profit, max_offload_profit): @@ -392,7 +372,8 @@ def _call_solver(self): else: raise NotImplementedError( f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, " - f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!") + f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!" + ) region_to_region_map.clear() @@ -452,7 +433,6 @@ def _repair_strategy(self): peak_mem_saving = 0 while len(self.region_to_region_map) and peak_mem_saving <= 0: - max_profit = (0,) best_ts = None undo_host_region = None @@ -464,8 +444,7 @@ def _repair_strategy(self): assert offload_region.need_offload assert not offload_region.is_syn - ts, profit = self._try_convert_to_syn_upload(host_region, - offload_region) + ts, profit = self._try_convert_to_syn_upload(host_region, offload_region) if self._compare_profit(profit, max_profit): undo_host_region = host_region @@ -474,7 +453,7 @@ def _repair_strategy(self): best_ts = ts if best_ts is None: - raise NotImplementedError('repair error!') + raise NotImplementedError("repair error!") assert not undo_offload_region.is_syn undo_offload_region.is_syn = True @@ -500,17 +479,13 @@ def _eval_one_choice(self): ts.execute() extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0) - profit = self._compute_offload_profit( - ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) return ts, profit class SolverFactory: - solvers: Dict[str, Type[Solver]] = { - 'syn': SynGreedySolver, - 'asyn': AsynGreedySolver - } + solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver} @staticmethod def create(solver_name: str) -> Type[Solver]: diff --git a/colossalai/auto_parallel/offload/training_simulator.py b/colossalai/auto_parallel/offload/training_simulator.py index de58023ec2d6..728d8daf9a46 100644 --- a/colossalai/auto_parallel/offload/training_simulator.py +++ b/colossalai/auto_parallel/offload/training_simulator.py @@ -1,7 +1,7 @@ import bisect -from typing import List, Dict -from collections import OrderedDict from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Dict, List from torch.fx.node import Node @@ -26,10 +26,7 @@ class TrainingSimulator(ABC): link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth. """ - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: self.region_list = region_list self.region_num = len(region_list) @@ -87,11 +84,7 @@ def _get_computing_overhead(self, flop: float) -> float: class SynTrainingSimulator(TrainingSimulator): - - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: super().__init__(region_list, comp_power, link_to_bw) def execute(self): @@ -115,8 +108,7 @@ def _eval_fwd_mem_per_region(self, region: Region): self.runtime_mem += region.param_size for node in region.nodes: - self.runtime_mem += calculate_fwd_tmp(node) + \ - calculate_fwd_out(node) + self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node) self.fwd_node_mem[node] = self.runtime_mem self.peak_mem = max(self.runtime_mem, self.peak_mem) self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem @@ -141,18 +133,15 @@ def _eval_bwd_mem_per_region(self, region: Region): self.runtime_mem += region.param_size for node in region.nodes.__reversed__(): - self.runtime_mem -= calculate_fwd_out(node) - self.runtime_mem += node.meta['bwd_mem_tmp'] + \ - node.meta['bwd_mem_out'] + self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] self.peak_mem = max(self.runtime_mem, self.peak_mem) # The memory savings of a node may be negative due to parameter prefetch. self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem self.bwd_node_mem[node] = self.runtime_mem - self.runtime_mem -= (node.meta['bwd_mem_tmp'] + - calculate_fwd_tmp(node)) + self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node) # free bwd_mem_out self.bwd_node_deps[node] = len(node.all_input_nodes) @@ -160,12 +149,14 @@ def _eval_bwd_mem_per_region(self, region: Region): if user_node in self.bwd_node_deps: self.bwd_node_deps[user_node] -= 1 if self.bwd_node_deps[user_node] <= 0: - self.runtime_mem -= user_node.meta['bwd_mem_out'] + self.runtime_mem -= user_node.meta["bwd_mem_out"] if self.runtime_mem < 0: - raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " - f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" - f"runtime memory computed less than 0, which is miscalculated!") + raise ValueError( + f"region id: {region.r_id}, node name: {node.name}, " + f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" + f"runtime memory computed less than 0, which is miscalculated!" + ) # release parameter and offload gradient in region if region.r_id == region.shared_rid: @@ -177,23 +168,16 @@ def _eval_bwd_mem_per_region(self, region: Region): class AsynTrainingSimulator(TrainingSimulator): - - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: super().__init__(region_list, comp_power, link_to_bw) self.iter_end_time: int = 0 # the last computation execution period - self.last_comp: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the last parameter prefetch execution period - self.last_h2d: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the last gradient offload execution period - self.last_d2h: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the forward computation execution period of the region self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the forward parameter prefetch execution period of the region @@ -204,10 +188,8 @@ def __init__(self, self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the gradient offload execution period of the region # which is divided into those that are waiting and those that have been released - self.bwd_reg_to_offl_waiting: OrderedDict[int, - ExecutionPeriod] = OrderedDict() - self.bwd_reg_to_offl_freed: OrderedDict[int, - ExecutionPeriod] = OrderedDict() + self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict() + self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the region buffer, which records regions that are offloaded but not released self.reg_buffer_to_free: List[int] = [] @@ -217,10 +199,8 @@ def __init__(self, # the region execution flow, # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU # when the execution reaches the i-th region. - self.fwd_reg_flow = torch.zeros( - (self.region_num, self.region_num)).bool() - self.bwd_reg_flow = torch.zeros( - (self.region_num, self.region_num)).bool() + self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool() + self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool() def execute(self): """ @@ -232,7 +212,7 @@ def execute(self): for reg in self.region_list: if reg.param_size and reg.r_id < self.region_num - 1: - for nr in self.region_list[reg.r_id + 1:]: + for nr in self.region_list[reg.r_id + 1 :]: if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]): reg.fwd_prefetch_region = nr break @@ -249,8 +229,7 @@ def execute(self): self.runtime_mem -= self.region_list[reg_id].param_size self.bwd_reg_to_offl_waiting.clear() - self.iter_end_time = max( - self.last_comp.end_time, self.last_d2h.end_time) + self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time) def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): """ @@ -258,10 +237,8 @@ def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): """ pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time) - pref_end_time = pref_start_time + \ - 2.0 * self._get_communication_overhead('h2d', region.param_size) - pref_ep = ExecutionPeriod( - start_time=pref_start_time, end_time=pref_end_time) + pref_end_time = pref_start_time + 2.0 * self._get_communication_overhead("h2d", region.param_size) + pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time) if is_fwd: self.fwd_reg_to_pref[region.r_id] = pref_ep else: @@ -276,18 +253,16 @@ def _insert_comp_exec(self, region: Region, is_fwd: bool = True): if is_fwd: reg_to_comp = self.fwd_reg_to_comp reg_to_pref = self.fwd_reg_to_pref - flop_key = 'fwd_flop' + flop_key = "fwd_flop" else: reg_to_comp = self.bwd_reg_to_comp reg_to_pref = self.bwd_reg_to_pref - flop_key = 'bwd_flop' - comp_start_time = max(self.last_comp.end_time, reg_to_pref.get( - region.r_id, ExecutionPeriod(0, 0)).end_time) - comp_end_time = comp_start_time + \ - sum([self._get_computing_overhead(node.meta.get(flop_key, 0)) - for node in region.nodes]) - comp_ep = ExecutionPeriod( - start_time=comp_start_time, end_time=comp_end_time) + flop_key = "bwd_flop" + comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time) + comp_end_time = comp_start_time + sum( + [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes] + ) + comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time) reg_to_comp[region.r_id] = comp_ep self.last_comp = comp_ep @@ -297,10 +272,8 @@ def _insert_d2h_exec(self, region: Region): """ offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time) - offl_end_time = offl_start_time + \ - self._get_communication_overhead('d2h', region.param_size) - offl_ep = ExecutionPeriod( - start_time=offl_start_time, end_time=offl_end_time) + offl_end_time = offl_start_time + self._get_communication_overhead("d2h", region.param_size) + offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time) self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep self.last_d2h = offl_ep @@ -332,20 +305,17 @@ def _eval_fwd_mem_per_region(self, region: Region): self.fwd_reg_flow[region.r_id, region.r_id] = True else: self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1] - self.fwd_reg_flow[region.r_id, - self.reg_buffer_to_free] = False + self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False self.reg_buffer_to_free.clear() # prefetch parameters of the next region fwd_prefetch_region = region.fwd_prefetch_region if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): self.runtime_mem += fwd_prefetch_region.param_size - self.fwd_reg_flow[region.r_id, - fwd_prefetch_region.r_id] = True + self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True for node in region.nodes: - self.runtime_mem += calculate_fwd_tmp(node) + \ - calculate_fwd_out(node) + self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node) self.peak_mem = max(self.runtime_mem, self.peak_mem) self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem @@ -354,8 +324,7 @@ def _eval_fwd_mem_per_region(self, region: Region): if region.need_offload: self.runtime_mem -= region.param_size - assert len( - self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}' + assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}" self.reg_buffer_to_free.append(region.r_id) def _eval_bwd_cost_per_region(self, region: Region): @@ -398,8 +367,7 @@ def _eval_bwd_mem_per_region(self, region: Region): self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1] else: self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1] - self.bwd_reg_flow[region.r_id, - self.reg_buffer_to_free] = False + self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False # free gradients in the buffer while len(self.reg_buffer_to_free): @@ -415,8 +383,7 @@ def _eval_bwd_mem_per_region(self, region: Region): bwd_prefetch_region = region.bwd_prefetch_region if bwd_prefetch_region: self.runtime_mem += bwd_prefetch_region.param_size - self.bwd_reg_flow[region.r_id, - bwd_prefetch_region.r_id] = True + self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True # add the gradient of the parameter if region.r_id < region.shared_rid: @@ -426,10 +393,8 @@ def _eval_bwd_mem_per_region(self, region: Region): self.runtime_mem += region.param_size for node in region.nodes.__reversed__(): - self.runtime_mem -= calculate_fwd_out(node) - self.runtime_mem += node.meta['bwd_mem_tmp'] + \ - node.meta['bwd_mem_out'] + self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] self.peak_mem = max(self.runtime_mem, self.peak_mem) # The memory savings of a node may be negative due to parameter prefetch. @@ -437,8 +402,7 @@ def _eval_bwd_mem_per_region(self, region: Region): self.bwd_node_mem[node] = self.runtime_mem - self.runtime_mem -= (node.meta['bwd_mem_tmp'] + - calculate_fwd_tmp(node)) + self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node) # free bwd_mem_out self.bwd_node_deps[node] = len(node.all_input_nodes) @@ -446,12 +410,14 @@ def _eval_bwd_mem_per_region(self, region: Region): if user_node in self.bwd_node_deps: self.bwd_node_deps[user_node] -= 1 if self.bwd_node_deps[user_node] <= 0: - self.runtime_mem -= user_node.meta['bwd_mem_out'] + self.runtime_mem -= user_node.meta["bwd_mem_out"] if self.runtime_mem < 0: - raise ValueError(f"region id: {region.r_id}, node name: {node.name}, " - f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" - f"runtime memory computed less than 0, which is miscalculated!") + raise ValueError( + f"region id: {region.r_id}, node name: {node.name}, " + f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---" + f"runtime memory computed less than 0, which is miscalculated!" + ) # release parameters of the region if requires_release_p_in_bwd(self.region_list[region.shared_rid]): diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py index 6b010512cc9c..cb65da79c5a2 100644 --- a/colossalai/auto_parallel/offload/util.py +++ b/colossalai/auto_parallel/offload/util.py @@ -35,7 +35,6 @@ class NvDevicePower: class GlobalRuntimeInfo(metaclass=SingletonMeta): - def __init__(self): self.h2d_stream = torch.cuda.Stream() self.d2h_stream = torch.cuda.Stream() @@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float: # forward for region in region_list: for node in region.nodes: - runtime_mem = runtime_mem + \ - calculate_fwd_tmp(node) + calculate_fwd_out(node) + runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node) act_peak_mem = max(runtime_mem, act_peak_mem) # backward bwd_deps = {} for region in region_list.__reversed__(): for node in region.nodes.__reversed__(): runtime_mem -= calculate_fwd_out(node) - runtime_mem = runtime_mem + \ - node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out'] + runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"] act_peak_mem = max(runtime_mem, act_peak_mem) - runtime_mem = runtime_mem - \ - node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node) + runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node) # free bwd_mem_out bwd_deps[node] = len(node.all_input_nodes) @@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float: if user_node in bwd_deps: bwd_deps[user_node] -= 1 if bwd_deps[user_node] <= 0: - runtime_mem -= user_node.meta['bwd_mem_out'] + runtime_mem -= user_node.meta["bwd_mem_out"] return act_peak_mem @@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float: def requires_upload_p_in_fwd(shared_reg: Region): - return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid - and shared_reg.need_offload) + return (shared_reg.r_id >= shared_reg.shared_rid) or ( + shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload + ) def requires_release_p_in_bwd(shared_reg: Region): - return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid - and shared_reg.need_offload) + return (shared_reg.r_id >= shared_reg.shared_rid) or ( + shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload + ) def requires_offload_g_in_bwd(region: Region): diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py index ffda58e0689f..ba290ee839d8 100644 --- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -14,18 +14,20 @@ shape_consistency_manager = ShapeConsistencyManager() -def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, - target_sharding_spec: ShardingSpec) -> ShardMetaInfo: +def _construct_shard_meta_info( + node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec +) -> ShardMetaInfo: # get comm_action_sequence and total_cost from shape_consistency_manager _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( - origin_sharding_spec, target_sharding_spec) + origin_sharding_spec, target_sharding_spec + ) meta_info = ShardMetaInfo() # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel # get mem cost for ShardMetaInfo mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) # extract user that has _meta_data and extract element length - input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) + input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data")) element_length = input_node._meta_data.element_size() mem_cost.fwd.activation *= element_length @@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, meta_info.memory_cost = mem_cost # get computation cost for ShardMetaInfo - meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, - total_cost['backward'] * element_length, - total_cost['total'] * element_length) + meta_info.compute_cost = TrainCycleItem( + total_cost["forward"] * element_length, + total_cost["backward"] * element_length, + total_cost["total"] * element_length, + ) # get tensor shape for ShardMetaInfo origin_sharding_spec: ShardingSpec @@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, input_shape = origin_sharding_spec.get_sharded_shape_per_device() output_shape = target_sharding_spec.get_sharded_shape_per_device() - meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_in = [torch.rand(input_shape, device="meta")] meta_info.fwd_buffer = [] - meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + meta_info.fwd_out = [torch.rand(output_shape, device="meta")] return meta_info @@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) - # extract node index and user node index args = node.args node_index, user_node_index = args[3], args[4] - origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ - user_node_index] + origin_sharding_spec, target_sharding_spec = ( + origin_spec_dict[node_index], + sharding_spec_dict[node_index][user_node_index], + ) return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) @@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S # this case is for all_reduce, there will be no memory cost meta_info = ShardMetaInfo() meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost) - output_node = next(n for n in node.users if hasattr(n, '_meta_data')) + output_node = next(n for n in node.users if hasattr(n, "_meta_data")) element_length = output_node._meta_data.element_size() total_cost = comm_action.comm_spec.get_comm_cost() - meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, - total_cost['backward'] * element_length, - total_cost['total'] * element_length) + meta_info.compute_cost = TrainCycleItem( + total_cost["forward"] * element_length, + total_cost["backward"] * element_length, + total_cost["total"] * element_length, + ) input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device() - meta_info.fwd_in = [torch.rand(input_shape, device='meta')] + meta_info.fwd_in = [torch.rand(input_shape, device="meta")] meta_info.fwd_buffer = [] - meta_info.fwd_out = [torch.rand(output_shape, device='meta')] + meta_info.fwd_out = [torch.rand(output_shape, device="meta")] else: # this case will be handled by shape consistency manager - origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[ - 'tgt_spec'] + origin_sharding_spec, target_sharding_spec = ( + comm_action.comm_spec["src_spec"], + comm_action.comm_spec["tgt_spec"], + ) meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) return meta_info -def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, - comm_actions_dict: Dict) -> GraphModule: +def comm_metainfo_pass( + gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict +) -> GraphModule: """ The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph. """ for node in gm.graph.nodes: if node.target == runtime_apply: - setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) + setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) elif node.target == runtime_comm_spec_apply: - setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) + setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) else: pass return gm diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index 0673b767de7b..9b000549de6c 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -21,16 +21,15 @@ def _normalize_tuple(x): @compatibility(is_backward_compatible=False) class MetaInfoProp: - def __init__(self, module: GraphModule) -> None: self.module = module self.func_dict = { - 'placeholder': self.placeholder_handler, - 'get_attr': self.get_attr_handler, - 'output': self.output_handler, - 'call_function': self.node_handler, - 'call_module': self.node_handler, - 'call_method': self.node_handler, + "placeholder": self.placeholder_handler, + "get_attr": self.get_attr_handler, + "output": self.output_handler, + "call_function": self.node_handler, + "call_module": self.node_handler, + "call_method": self.node_handler, } def _set_data_ptr(self, x): @@ -46,7 +45,7 @@ def _is_inplace(self, node: Node): """ Check if the node is inplace operation. """ - if node.op == 'call_module': + if node.op == "call_module": return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD elif node.op == "call_function": return node.target in OUTPUT_SAVED_OPS @@ -66,7 +65,7 @@ def placeholder_handler(self, node: Node) -> None: Handle the placeholder node. """ graph_info = GraphInfo() - out = _normalize_tuple(getattr(node, '_meta_data', None)) + out = _normalize_tuple(getattr(node, "_meta_data", None)) graph_info.fwd_out = list(out) if out[0] is not None else [] node.meta = {**asdict(graph_info)} @@ -96,7 +95,7 @@ def node_handler(self, node: Node) -> None: """ Handle other kind of nodes """ - assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}" + assert hasattr(node, "best_strategy_info"), f"Cannot find best_strategy_info in node {node}, {node.op}" graph_info = GraphInfo() meta_info = node.best_strategy_info meta_info: ShardMetaInfo @@ -126,7 +125,8 @@ def node_handler(self, node: Node) -> None: for tensor in par.meta.get("fwd_out", []): tensor: torch.Tensor target_input_tensor = next( - (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None) + (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None + ) if target_input_tensor is not None: target_input_tensor.data_ptr = tensor.data_ptr diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index 2049a06187d2..27afe72c0db8 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -1,18 +1,10 @@ -from copy import deepcopy from typing import Dict, List import torch from torch.fx.node import Node from colossalai._analyzer.fx.node_util import MetaInfo -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - OperationData, - OperationDataType, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType from colossalai.tensor.comm_spec import CommSpec from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec @@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec) -def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, - user_node_index: int): +def runtime_apply_for_iterable_object( + node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int +): """ This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list is converted into the user node expected form. """ rst = [] - for index, (origin_sharding_spec, - target_sharding_spec) in enumerate(zip(origin_dict[node_index], - input_dict[node_index][user_node_index])): + for index, (origin_sharding_spec, target_sharding_spec) in enumerate( + zip(origin_dict[node_index], input_dict[node_index][user_node_index]) + ): rst.append( - shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec, - target_sharding_spec)) + shape_consistency_manager.apply_for_autoparallel_runtime( + node[index], origin_sharding_spec, target_sharding_spec + ) + ) rst = type(node)(rst) return rst @@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_ if isinstance(comm_action.comm_spec, CommSpec): rst = comm_action.comm_spec.covert_spec_to_action(tensor) else: - origin_sharding_spec = comm_action.comm_spec['src_spec'] - tgt_sharding_spec = comm_action.comm_spec['tgt_spec'] + origin_sharding_spec = comm_action.comm_spec["src_spec"] + tgt_sharding_spec = comm_action.comm_spec["tgt_spec"] rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec) return rst @@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]): node_to_index_dict = {} index = 0 for node in nodes: - if node.target == 'sharding_spec_convert_dict': + if node.target == "sharding_spec_convert_dict": input_dict_node = node continue - if node.target == 'origin_node_sharding_spec_dict': + if node.target == "origin_node_sharding_spec_dict": origin_dict_node = node continue - if node.target == 'comm_actions_dict': + if node.target == "comm_actions_dict": comm_actions_dict_node = node continue - if not hasattr(node, 'best_strategy'): + if not hasattr(node, "best_strategy"): continue node_to_index_dict[node] = index index += 1 @@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes) for node in nodes: - if not hasattr(node, 'best_strategy') or node.op == 'output': + if not hasattr(node, "best_strategy") or node.op == "output": continue for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes): if isinstance(node.sharding_spec, (list, tuple)): assert isinstance( - node.target_sharding_specs, - (list, - tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list' + node.target_sharding_specs, (list, tuple) + ), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list" total_difference = 0 - for sharding_spec, target_sharding_spec in zip(node.sharding_spec, - node.target_sharding_specs[user_node_index]): + for sharding_spec, target_sharding_spec in zip( + node.sharding_spec, node.target_sharding_specs[user_node_index] + ): total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec) if total_difference == 0: continue with mod_graph.inserting_before(user_node): - shape_consistency_node = mod_graph.create_node('call_function', - runtime_apply_for_iterable_object, - args=(node, origin_dict_node, input_dict_node, - node_to_index_dict[node], user_node_index)) + shape_consistency_node = mod_graph.create_node( + "call_function", + runtime_apply_for_iterable_object, + args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index), + ) else: - assert isinstance(node.sharding_spec, - ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.' + assert isinstance( + node.sharding_spec, ShardingSpec + ), "node.sharding_spec should be type of ShardingSpec, tuple or list." if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0: continue with mod_graph.inserting_before(user_node): - shape_consistency_node = mod_graph.create_node('call_function', - runtime_apply, - args=(node, origin_dict_node, input_dict_node, - node_to_index_dict[node], user_node_index)) - if hasattr(user_node.meta['info'], 'activation_checkpoint'): - MetaInfo(shape_consistency_node, - mod_dir=user_node.meta['info'].mod_dir, - activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint)) + shape_consistency_node = mod_graph.create_node( + "call_function", + runtime_apply, + args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index), + ) + if hasattr(user_node.meta["info"], "activation_checkpoint"): + MetaInfo( + shape_consistency_node, + mod_dir=user_node.meta["info"].mod_dir, + activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint), + ) new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs) # the origin node may be a positional argument or key word argument of user node @@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes) for node in nodes: - if not hasattr(node, 'best_strategy') or node.op == 'output': + if not hasattr(node, "best_strategy") or node.op == "output": continue comm_actions = node.best_strategy.communication_actions for op_data, comm_action in comm_actions.items(): - if comm_action.comm_type == CommType.HOOK: continue if comm_action.comm_type == CommType.BEFORE: @@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): else: comm_object = node.args[comm_action.arg_index] with mod_graph.inserting_before(node): - comm_spec_apply_node = mod_graph.create_node('call_function', - runtime_comm_spec_apply, - args=(comm_object, comm_actions_dict_node, - node_to_index_dict[node], op_data.name)) + comm_spec_apply_node = mod_graph.create_node( + "call_function", + runtime_comm_spec_apply, + args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name), + ) # the origin node may be a positional argument or key word argument of user node if comm_action.key_for_kwarg is not None: # substitute the origin node with comm_spec_apply_node @@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): elif comm_action.comm_type == CommType.AFTER: with mod_graph.inserting_after(node): - comm_spec_apply_node = mod_graph.create_node('call_function', - runtime_comm_spec_apply, - args=(node, comm_actions_dict_node, - node_to_index_dict[node], op_data.name)) + comm_spec_apply_node = mod_graph.create_node( + "call_function", + runtime_comm_spec_apply, + args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name), + ) user_list = list(node.users.keys()) for user in user_list: if user == comm_spec_apply_node: @@ -211,10 +212,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): # substitute the origin node with comm_spec_apply_node new_kwargs[str(node)] = comm_spec_apply_node user.kwargs = new_kwargs - if hasattr(node.meta['info'], 'activation_checkpoint'): - MetaInfo(comm_spec_apply_node, - mod_dir=node.meta['info'].mod_dir, - activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) + if hasattr(node.meta["info"], "activation_checkpoint"): + MetaInfo( + comm_spec_apply_node, + mod_dir=node.meta["info"].mod_dir, + activation_checkpoint=tuple(node.meta["info"].activation_checkpoint), + ) return gm @@ -227,21 +230,21 @@ def _act_annotation_pass(gm: torch.fx.GraphModule): nodes = tuple(mod_graph.nodes) for node in nodes: - if not hasattr(node.meta, 'activation_checkpoint'): - from .runtime_preparation_pass import size_processing + if not hasattr(node.meta, "activation_checkpoint"): + pass user_act_annotation = -1 input_act_annotation = -1 for user_node in node.users.keys(): - if 'activation_checkpoint' in user_node.meta: - user_act_annotation = user_node.meta['activation_checkpoint'] + if "activation_checkpoint" in user_node.meta: + user_act_annotation = user_node.meta["activation_checkpoint"] break for input_node in node._input_nodes.keys(): - if 'activation_checkpoint' in input_node.meta: - input_act_annotation = input_node.meta['activation_checkpoint'] + if "activation_checkpoint" in input_node.meta: + input_act_annotation = input_node.meta["activation_checkpoint"] break if user_act_annotation == input_act_annotation and user_act_annotation != -1: - node.meta['activation_checkpoint'] = user_act_annotation + node.meta["activation_checkpoint"] = user_act_annotation return gm diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 0ed0742ee57e..65c3d8e0cbeb 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -1,19 +1,12 @@ import operator -from copy import deepcopy from typing import Dict, List, Union import torch -from torch.fx import symbolic_trace from torch.fx.node import Node from colossalai._analyzer.fx.node_util import MetaInfo from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - OperationDataType, - ShardingStrategy, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.comm_spec import _all_reduce @@ -25,11 +18,13 @@ shape_consistency_manager = ShapeConsistencyManager() -def size_processing(size: Union[int, torch.Size], - dim_partition_dict: Dict[int, List[int]], - device_mesh_info: Dict[int, int], - target_dim: int = None, - node_name: str = None): +def size_processing( + size: Union[int, torch.Size], + dim_partition_dict: Dict[int, List[int]], + device_mesh_info: Dict[int, int], + target_dim: int = None, + node_name: str = None, +): """ This method will be invoked during runtime to convert size node value depending on distributed information. """ @@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size], return size -def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], - strategies_constructor: StrategiesConstructor): +def solution_annotation_pass( + gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor +): """ This method is used to stick the solution strategy to the nodes and add the information required in runtime into graph as placeholder nodes. @@ -70,14 +66,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): strategies_vector = node.strategies_vector # stick the solution strategy to the corresponding node - setattr(node, 'best_strategy', strategies_vector[strategy_index]) - setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) + setattr(node, "best_strategy", strategies_vector[strategy_index]) + setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))) origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( - str(node)) + str(node) + ) # attach the corresponding metainfo if node has the attribute `strategies_info` - if hasattr(node, 'strategies_info'): - setattr(node, 'best_strategy_info', node.strategies_info[strategy_index]) + if hasattr(node, "strategies_info"): + setattr(node, "best_strategy_info", node.strategies_info[strategy_index]) # the dict to get input sharding specs of user node sharding_spec_convert_dict = {} @@ -92,15 +89,15 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name)) target_sharding_specs.append(target_sharding_spec) sharding_spec_convert_dict[index] = target_sharding_specs - setattr(node, 'target_sharding_specs', target_sharding_specs) + setattr(node, "target_sharding_specs", target_sharding_specs) # the get_attr node strategy is kind of pending strategy, which means we will change it # to the same strategy of the user node. - if node.op == 'get_attr': - assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.' + if node.op == "get_attr": + assert len(target_sharding_specs) == 1, f"sharing weight is not supported in current version." target_node = node.strategies_vector.successor_nodes[0] node_name = str(node) - if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP: + if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP: node_name = str(target_node) target_node = target_node.strategies_vector.successor_nodes[0] user_strategy = target_node.best_strategy @@ -122,11 +119,11 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], # add above dicts into graph for node in nodes: - if node.op != 'placeholder': + if node.op != "placeholder": with mod_graph.inserting_before(node): - input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') - origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') - comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict') + input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict") + origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict") + comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict") break return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict @@ -148,7 +145,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh device_mesh_info[dim] = dim_size def _extract_target_dim(node): - ''' + """ A helper function to extract the target dimension from size node. There are two usages of torch.Tensor.size: 1. tensor.size() @@ -156,7 +153,7 @@ def _extract_target_dim(node): If a target_dim is assigned, then the output will be in type of int, instead of torch.Size. Otherwise, the output will be in type of torch.Size and this function will return None. - ''' + """ target_dim = None if len(node.args) > 1: target_dim = node.args[1] @@ -165,19 +162,21 @@ def _extract_target_dim(node): return target_dim def _post_processing(node, size_processing_node): - ''' + """ This function is used to process the dependency between the size node and its users after inserting the size_process_node. - ''' + """ # store original node and processing node pair in node_pairs dictionary # It will be used to replace the original node with processing node in slice object node_pairs[node] = size_processing_node size_processing_node._meta_data = node._meta_data - if hasattr(node.meta['info'], 'activation_checkpoint'): - MetaInfo(size_processing_node, - mod_dir=node.meta['info'].mod_dir, - activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) + if hasattr(node.meta["info"], "activation_checkpoint"): + MetaInfo( + size_processing_node, + mod_dir=node.meta["info"].mod_dir, + activation_checkpoint=tuple(node.meta["info"].activation_checkpoint), + ) user_list = list(node.users.keys()) for user in user_list: @@ -196,10 +195,10 @@ def _post_processing(node, size_processing_node): user.kwargs = new_kwargs def _update_slice_object_args(slice_object): - ''' + """ This function is used to update the slice object argument list. If the slice object contains the Node argument, then the size node will be replaced with - ''' + """ if isinstance(slice_object, slice): start = slice_object.start stop = slice_object.stop @@ -220,8 +219,7 @@ def _update_slice_object_args(slice_object): raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}") for node in nodes: - - if node.op == 'call_method' and node.target == 'size': + if node.op == "call_method" and node.target == "size": # extract useful information from size node # dim_partition_dict will instruct the size value on which # dimension should be enlarged. @@ -232,14 +230,14 @@ def _update_slice_object_args(slice_object): # insert size_processing node with mod_graph.inserting_after(node): - size_processing_node = mod_graph.create_node('call_function', - size_processing, - args=(node, dim_partition_dict, device_mesh_info, - target_dim, node.name)) + size_processing_node = mod_graph.create_node( + "call_function", + size_processing, + args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name), + ) _post_processing(node, size_processing_node) - if node.op == 'call_function' and node.target == operator.getitem: - + if node.op == "call_function" and node.target == operator.getitem: getitem_index = node.args[1] # slice object is quite special in torch.fx graph, # On one side, we treat slice object same as type of int, @@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh) nodes = tuple(mod_graph.nodes) def _extract_info_from_sharding_spec(sharding_spec): - ''' + """ This function is used to extract the dim_partition_dict and device_mesh from sharding spec instance or a list of sharding spec. - ''' + """ if isinstance(sharding_spec, ShardingSpec): dim_partition_dict = sharding_spec.dim_partition_dict device_mesh = sharding_spec.device_mesh return dim_partition_dict, device_mesh if sharding_spec is None: return None, None - assert isinstance(sharding_spec, - (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None' + assert isinstance( + sharding_spec, (tuple, list) + ), "sharding_spec should be type of ShardingSpec, tuple, list or None" device_mesh = sharding_spec[0].device_mesh dim_partition_dict = [] @@ -322,8 +321,9 @@ def _process_node_arguments(node): else: new_args.append(arg) else: - assert isinstance(arg, - (int, tuple, list)), 'The argument in view node should be either type of Node or int.' + assert isinstance( + arg, (int, tuple, list) + ), "The argument in view node should be either type of Node or int." if isinstance(arg, (tuple, list)): new_args.extend(arg) else: @@ -332,7 +332,7 @@ def _process_node_arguments(node): def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): new_args = _process_node_arguments(node) - if node.op == 'call_method': + if node.op == "call_method": args_to_process = list(new_args[1:]) else: args_to_process = list(new_args) @@ -350,7 +350,7 @@ def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): args_to_process = tuple(args_to_process) - if node.op == 'call_method': + if node.op == "call_method": new_args = (new_args[0],) + args_to_process else: new_args = args_to_process @@ -358,9 +358,9 @@ def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node): node.args = new_args def _filter_node_with_shape_args(node): - if node.op == 'call_method': + if node.op == "call_method": target = getattr(node.args[0]._meta_data.__class__, node.target) - elif node.op == 'call_function': + elif node.op == "call_function": target = node.target else: target = None @@ -371,7 +371,7 @@ def _filter_node_with_shape_args(node): for node in nodes: # skip the placeholder node added in _solution_annotation pass - if not hasattr(node, 'sharding_spec'): + if not hasattr(node, "sharding_spec"): continue output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec) @@ -392,15 +392,21 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes reduction_stream = torch.cuda.Stream() def _add_hook_for_grad_communication(node, param, name=None): - comm_actions = node.best_strategy.communication_actions def _filter_param_to_hook(node, op_data, comm_action, name): - - if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK: + if ( + node.op == "call_module" + and op_data.type == OperationDataType.PARAM + and op_data.name == name + and comm_action.comm_type == CommType.HOOK + ): return True - if node.op == 'get_attr' and isinstance( - node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK: + if ( + node.op == "get_attr" + and isinstance(node._meta_data, torch.nn.parameter.Parameter) + and comm_action.comm_type == CommType.HOOK + ): return True return False @@ -410,7 +416,6 @@ def _filter_param_to_hook(node, op_data, comm_action, name): if _filter_param_to_hook(node, operation_data, comm_action, name=name): def wrapper(param, comm_spec, stream, overlap): - def hook_fn(grad): if overlap: with torch.cuda.stream(stream): @@ -426,22 +431,26 @@ def _shard_param(param, target_sharding_spec): # apply the sharding spec of parameters if target_sharding_spec.dim_partition_dict != {}: origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {}) - setattr(param, 'sharding_spec', origin_sharding_spec) + setattr(param, "sharding_spec", origin_sharding_spec) # TODO: build a ColoParameter class to manager the distributed parameters # we could use .data here, because all the operations just happen before the real training # loop, so we don't need to track these operations in the autograd graph. param = torch.nn.Parameter( - shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, - target_sharding_spec).detach().clone()) + shape_consistency_manager.apply_for_autoparallel_runtime( + param.data, param.sharding_spec, target_sharding_spec + ) + .detach() + .clone() + ) return param for node in nodes: - if node.op == 'call_module': + if node.op == "call_module": target_module = node.graph.owning_module.get_submodule(node.target) # TODO: we need to do more actions to take care of the shared parameters. - if hasattr(target_module, 'processed') and target_module.processed: + if hasattr(target_module, "processed") and target_module.processed: continue - setattr(target_module, 'processed', True) + setattr(target_module, "processed", True) for name, param in target_module.named_parameters(): target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) param = _shard_param(param, target_sharding_spec) @@ -453,7 +462,7 @@ def _shard_param(param, target_sharding_spec): # apply the sharding spec of buffers for name, buffer in target_module.named_buffers(): origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {}) - setattr(buffer, 'sharding_spec', origin_sharding_spec) + setattr(buffer, "sharding_spec", origin_sharding_spec) target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec) sharded_buffer_dict[name] = buffer_sharded @@ -461,7 +470,7 @@ def _shard_param(param, target_sharding_spec): for name, buffer_sharded in sharded_buffer_dict.items(): setattr(target_module, name, buffer_sharded.detach().clone()) - if node.op == 'get_attr': + if node.op == "get_attr": root = node.graph.owning_module atoms = node.target.split(".") attr_len = len(atoms) @@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule): """ replace the origin kernel into kernel with implicit communication inside. """ - pass -def runtime_preparation_pass(gm: torch.fx.GraphModule, - solution: List[int], - device_mesh: DeviceMesh, - strategies_constructor: StrategiesConstructor, - overlap=False): +def runtime_preparation_pass( + gm: torch.fx.GraphModule, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor, + overlap=False, +): gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass( - gm, solution, strategies_constructor) + gm, solution, strategies_constructor + ) gm = size_value_converting_pass(gm, device_mesh) gm = node_args_converting_pass(gm, device_mesh) # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed. diff --git a/colossalai/auto_parallel/tensor_shard/constants.py b/colossalai/auto_parallel/tensor_shard/constants.py index 99c124934060..e9c2c8664a61 100644 --- a/colossalai/auto_parallel/tensor_shard/constants.py +++ b/colossalai/auto_parallel/tensor_shard/constants.py @@ -3,9 +3,22 @@ import torch __all__ = [ - 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', - 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP', - 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST' + "ELEMENTWISE_MODULE_OP", + "ELEMENTWISE_FUNC_OP", + "RESHAPE_FUNC_OP", + "CONV_MODULE_OP", + "CONV_FUNC_OP", + "LINEAR_MODULE_OP", + "LINEAR_FUNC_OP", + "BATCHNORM_MODULE_OP", + "POOL_MODULE_OP", + "NON_PARAM_FUNC_OP", + "BCAST_FUNC_OP", + "EMBEDDING_MODULE_OP", + "LAYERNORM_MODULE_OP", + "ELEMENTWISE_METHOD_OP", + "RESHAPE_METHOD_OP", + "INFINITY_COST", ] ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] @@ -18,13 +31,13 @@ torch.nn.functional.relu, torch.nn.functional.dropout, # softmax should not be here - torch.nn.functional.softmax + torch.nn.functional.softmax, ] ELEMENTWISE_METHOD_OP = [ torch.Tensor.to, torch.Tensor.type, # TODO: contiguous maybe need some extra processes. - torch.Tensor.contiguous + torch.Tensor.contiguous, ] RESHAPE_FUNC_OP = [ torch.flatten, @@ -42,15 +55,36 @@ torch.Tensor.transpose, ] BCAST_FUNC_OP = [ - torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, - operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow + torch.add, + torch.sub, + torch.mul, + torch.div, + torch.floor_divide, + torch.true_divide, + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.truediv, + torch.matmul, + operator.pow, + torch.pow, ] CONV_MODULE_OP = [ - torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, - torch.nn.ConvTranspose3d + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, ] CONV_FUNC_OP = [ - torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d + torch.conv1d, + torch.conv2d, + torch.conv3d, + torch.conv_transpose1d, + torch.conv_transpose2d, + torch.conv_transpose3d, ] EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding] LINEAR_MODULE_OP = [torch.nn.Linear] @@ -85,7 +119,7 @@ operator.floordiv, operator.truediv, # softmax should not be here - torch.nn.functional.softmax + torch.nn.functional.softmax, ] INFINITY_COST = 1e13 diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py index b406ca6fb7e0..d82f0ef53f66 100644 --- a/colossalai/auto_parallel/tensor_shard/initialize.py +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -3,7 +3,6 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch.fx import GraphModule from torch.fx.graph import Graph from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen @@ -14,27 +13,32 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction -from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.sharding_spec import ShardingSpec class ModuleWrapper(nn.Module): - ''' + """ This class is used to wrap the original module, and add the sharding_spec_dict, origin_spec_dict, comm_actions_dict into the forward function. - ''' - - def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]], - origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]): - ''' + """ + + def __init__( + self, + module: ColoGraphModule, + sharding_spec_dict: Dict[int, List[ShardingSpec]], + origin_spec_dict: Dict[int, ShardingSpec], + comm_actions_dict: Dict[int, Dict[str, CommAction]], + ): + """ Args: module: the original module sharding_spec_dict: The sharding_spec_dict is used to record the target sharding specs of each tensor required in user node. origin_spec_dict: The origin_spec_dict is used to record the original sharding spec of each tensor. comm_actions_dict: The comm_actions_dict is used to record the communication actions of each tensor. - ''' + """ super(ModuleWrapper, self).__init__() self.module = module self.sharding_spec_dict = sharding_spec_dict @@ -42,67 +46,68 @@ def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[S self.comm_actions_dict = comm_actions_dict def forward(self, *args, **kwargs): - return self.module(*args, - sharding_spec_convert_dict=self.sharding_spec_dict, - origin_node_sharding_spec_dict=self.origin_spec_dict, - comm_actions_dict=self.comm_actions_dict, - **kwargs) + return self.module( + *args, + sharding_spec_convert_dict=self.sharding_spec_dict, + origin_node_sharding_spec_dict=self.origin_spec_dict, + comm_actions_dict=self.comm_actions_dict, + **kwargs, + ) def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable): - ''' + """ This method is used to extract the meta_args from the dataloader under the instruction of the data_process_func. - ''' + """ # TODO: implement this function - pass def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]): - ''' + """ This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape from the alpha_beta_dict. These two values will be used to estimate the communication cost. - ''' + """ # TODO: implement this function - pass -def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, - shard_option: str): - ''' +def build_strategy_constructor( + graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, shard_option: str +): + """ This method is used to build the strategy_constructor for the given graph. After this method, each node in the graph will have a strategies_vector which is constructed by the related node handler. - ''' - if solver_preference == 'standard': + """ + if solver_preference == "standard": solver_preference = SolverPerference.STANDARD - elif solver_preference == 'tp': + elif solver_preference == "tp": solver_preference = SolverPerference.TP - elif solver_preference == 'dp': + elif solver_preference == "dp": solver_preference = SolverPerference.DP else: - raise ValueError(f'Invalid solver_preference: {solver_preference}') + raise ValueError(f"Invalid solver_preference: {solver_preference}") - if dataloader_option == 'replicated': + if dataloader_option == "replicated": dataloader_option = DataloaderOption.REPLICATED - elif dataloader_option == 'distributed': + elif dataloader_option == "distributed": dataloader_option = DataloaderOption.DISTRIBUTED else: - raise ValueError(f'Invalid dataloader_option: {dataloader_option}') + raise ValueError(f"Invalid dataloader_option: {dataloader_option}") - if shard_option == 'standard': + if shard_option == "standard": shard_option = ShardOption.STANDARD - elif shard_option == 'shard': + elif shard_option == "shard": shard_option = ShardOption.SHARD - elif shard_option == 'shard_last_axis': + elif shard_option == "shard_last_axis": shard_option = ShardOption.SHARD_LAST_AXIS - elif shard_option == 'full_shard': + elif shard_option == "full_shard": shard_option = ShardOption.FULL_SHARD else: - raise ValueError(f'Invalid shard_option: {shard_option}') + raise ValueError(f"Invalid shard_option: {shard_option}") - solver_options = SolverOptions(solver_perference=solver_preference, - dataloader_option=dataloader_option, - shard_option=shard_option) + solver_options = SolverOptions( + solver_perference=solver_preference, dataloader_option=dataloader_option, shard_option=shard_option + ) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() @@ -110,10 +115,10 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_pre def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0): - ''' + """ This method is used to solve the best solution for the given graph. The solution is a list of integers, each integer represents the best strategy index of the corresponding node. - ''' + """ # temporarily we use all nodes as liveness list, we count the backward memory cost together with # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase. # graph_analyser = GraphAnalyser(gm) @@ -127,23 +132,23 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc return solution -def transform_to_sharded_model(gm: ColoGraphModule, - meta_args: Dict, - solution: List[int], - device_mesh: DeviceMesh, - strategies_constructor: StrategiesConstructor, - overlap: bool = False): - ''' +def transform_to_sharded_model( + gm: ColoGraphModule, + meta_args: Dict, + solution: List[int], + device_mesh: DeviceMesh, + strategies_constructor: StrategiesConstructor, + overlap: bool = False, +): + """ This method is used to transform the original graph to the sharded graph. The model parameters will be sharded according to the solution and the grad hooks will be added to the sharded graph using the runtime_preparation_pass. The communication node will be added into the graph using the runtime_apply_pass. - ''' - gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, - solution, - device_mesh, - strategies_constructor, - overlap=overlap) + """ + gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( + gm, solution, device_mesh, strategies_constructor, overlap=overlap + ) gm = runtime_apply_pass(gm) shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict) gm.recompile() @@ -152,12 +157,14 @@ def transform_to_sharded_model(gm: ColoGraphModule, return gm, sharding_spec_dicts -def initialize_device_mesh(world_size: int = -1, - physical_devices: List[int] = None, - alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, - logical_mesh_shape: Tuple[int] = None, - logical_mesh_id: torch.Tensor = None): - ''' +def initialize_device_mesh( + world_size: int = -1, + physical_devices: List[int] = None, + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, + logical_mesh_shape: Tuple[int] = None, + logical_mesh_id: torch.Tensor = None, +): + """ This method is used to initialize the device mesh. Args: @@ -170,7 +177,7 @@ def initialize_device_mesh(world_size: int = -1, logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical mesh shape. logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id. - ''' + """ # if world_size is not set, use the world size from torch.distributed if world_size == -1: world_size = dist.get_world_size() @@ -201,27 +208,31 @@ def initialize_device_mesh(world_size: int = -1, # extract alpha and beta values for the chosen logical mesh shape mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id) - device_mesh = DeviceMesh(physical_mesh_id=physical_mesh, - logical_mesh_id=logical_mesh_id, - mesh_alpha=mesh_alpha, - mesh_beta=mesh_beta, - init_process_group=True) + device_mesh = DeviceMesh( + physical_mesh_id=physical_mesh, + logical_mesh_id=logical_mesh_id, + mesh_alpha=mesh_alpha, + mesh_beta=mesh_beta, + init_process_group=True, + ) return device_mesh -def initialize_model(model: nn.Module, - meta_args: Dict[str, torch.Tensor], - device_mesh: DeviceMesh, - memory_budget: float = -1.0, - overlap: bool = False, - solver_preference: str = 'standard', - dataloader_option: str = 'replicated', - shard_option: str = 'standard', - save_solver_solution: bool = False, - load_solver_solution: bool = False, - solution_path: str = None, - return_solution: bool = False): - ''' +def initialize_model( + model: nn.Module, + meta_args: Dict[str, torch.Tensor], + device_mesh: DeviceMesh, + memory_budget: float = -1.0, + overlap: bool = False, + solver_preference: str = "standard", + dataloader_option: str = "replicated", + shard_option: str = "standard", + save_solver_solution: bool = False, + load_solver_solution: bool = False, + solution_path: str = None, + return_solution: bool = False, +): + """ This method is used to initialize the sharded model which could be used as normal pytorch model. Args: @@ -246,7 +257,7 @@ def initialize_model(model: nn.Module, return_solution(optional): if the return_solution is True, the solution will be returned. The returned solution will be used to debug or help to analyze the sharding result. Therefore, we will not just return a series of integers, but return the best strategies. - ''' + """ tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_args) @@ -256,11 +267,13 @@ def initialize_model(model: nn.Module, shape_prop_pass(gm, *meta_args.values()) gm.recompile() - strategies_constructor = build_strategy_constructor(graph, - device_mesh, - solver_preference=solver_preference, - dataloader_option=dataloader_option, - shard_option=shard_option) + strategies_constructor = build_strategy_constructor( + graph, + device_mesh, + solver_preference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option, + ) if load_solver_solution: solution = torch.load(solution_path) else: @@ -268,8 +281,9 @@ def initialize_model(model: nn.Module, if save_solver_solution: torch.save(solution, solution_path) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor, - overlap) + gm, sharding_spec_dicts = transform_to_sharded_model( + gm, meta_args, solution, device_mesh, strategies_constructor, overlap + ) model_to_return = ModuleWrapper(gm, *sharding_spec_dicts) @@ -277,28 +291,30 @@ def initialize_model(model: nn.Module, solution_to_return = [] nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] for index, node in enumerate(nodes): - solution_to_return.append(f'{node.name} {node.strategies_vector[solution[index]].name}') + solution_to_return.append(f"{node.name} {node.strategies_vector[solution[index]].name}") return model_to_return, solution_to_return else: return model_to_return -def autoparallelize(model: nn.Module, - meta_args: Dict[str, torch.Tensor] = None, - data_loader: torch.utils.data.DataLoader = None, - data_process_func: callable = None, - alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, - logical_mesh_shape: Tuple[int] = None, - logical_mesh_id: torch.Tensor = None, - solver_preference: str = 'standard', - dataloader_option: str = 'replicated', - shard_option: str = 'standard', - save_solver_solution: bool = False, - load_solver_solution: bool = False, - solver_solution_path: str = None, - return_solution: bool = False, - memory_budget: float = -1.0): - ''' +def autoparallelize( + model: nn.Module, + meta_args: Dict[str, torch.Tensor] = None, + data_loader: torch.utils.data.DataLoader = None, + data_process_func: callable = None, + alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, + logical_mesh_shape: Tuple[int] = None, + logical_mesh_id: torch.Tensor = None, + solver_preference: str = "standard", + dataloader_option: str = "replicated", + shard_option: str = "standard", + save_solver_solution: bool = False, + load_solver_solution: bool = False, + solver_solution_path: str = None, + return_solution: bool = False, + memory_budget: float = -1.0, +): + """ This method is used to initialize the device mesh, extract the meta_args, and use them to create a sharded model. @@ -329,24 +345,26 @@ def autoparallelize(model: nn.Module, return_solution(optional): if the return_solution is True, the solution will be returned. memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0, the memory budget will be infinity. - ''' - device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, - logical_mesh_shape=logical_mesh_shape, - logical_mesh_id=logical_mesh_id) + """ + device_mesh = initialize_device_mesh( + alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape, logical_mesh_id=logical_mesh_id + ) if meta_args is None: meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func) - rst_to_unpack = initialize_model(model, - meta_args, - device_mesh, - solver_preference=solver_preference, - dataloader_option=dataloader_option, - shard_option=shard_option, - save_solver_solution=save_solver_solution, - load_solver_solution=load_solver_solution, - solution_path=solver_solution_path, - return_solution=return_solution, - memory_budget=memory_budget) + rst_to_unpack = initialize_model( + model, + meta_args, + device_mesh, + solver_preference=solver_preference, + dataloader_option=dataloader_option, + shard_option=shard_option, + save_solver_solution=save_solver_solution, + load_solver_solution=load_solver_solution, + solution_path=solver_solution_path, + return_solution=return_solution, + memory_budget=memory_budget, + ) if return_solution: model, solution = rst_to_unpack diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py index 9903ca54e52c..aa2e5e9c40c0 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py @@ -25,11 +25,33 @@ from .where_handler import WhereHandler __all__ = [ - 'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler', - 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', - 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler', - 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', - 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', - 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler', - 'SplitHandler' + "LinearFunctionHandler", + "LinearModuleHandler", + "BMMFunctionHandler", + "AddBMMFunctionHandler", + "LayerNormModuleHandler", + "BatchNormModuleHandler", + "ConvModuleHandler", + "ConvFunctionHandler", + "UnaryElementwiseHandler", + "DefaultReshapeHandler", + "PlaceholderHandler", + "OutputHandler", + "WhereHandler", + "NormPoolingHandler", + "BinaryElementwiseHandler", + "MatMulHandler", + "operator_registry", + "ADDMMFunctionHandler", + "GetItemHandler", + "GetattrHandler", + "ViewHandler", + "PermuteHandler", + "TensorConstructorHandler", + "EmbeddingModuleHandler", + "EmbeddingFunctionHandler", + "SumHandler", + "SoftmaxHandler", + "TransposeHandler", + "SplitHandler", ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py index da0d199c5e05..47c654d6aa43 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py @@ -2,15 +2,13 @@ import torch -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager - -from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator -__all__ = ['ADDMMFunctionHandler'] +__all__ = ["ADDMMFunctionHandler"] @operator_registry.register(torch.addmm) @@ -30,25 +28,26 @@ def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType: return data_type def get_operation_data_mapping(self) -> Dict[str, OperationData]: - # input operand input_data = self.node.args[1]._meta_data - physical_input_operand = OperationData(name=str(self.node.args[1]), - type=self._infer_op_data_type(input_data), - data=input_data) + physical_input_operand = OperationData( + name=str(self.node.args[1]), type=self._infer_op_data_type(input_data), data=input_data + ) # other operand other_data = self.node.args[2]._meta_data - physical_other_operand = OperationData(name=str(self.node.args[2]), - type=self._infer_op_data_type(other_data), - data=other_data) + physical_other_operand = OperationData( + name=str(self.node.args[2]), type=self._infer_op_data_type(other_data), data=other_data + ) # bias physical shape bias_logical_shape = self.node._meta_data.shape bias_data = self.node.args[0]._meta_data - physical_bias_operand = OperationData(name=str(self.node.args[0]), - type=self._infer_op_data_type(bias_data), - data=bias_data, - logical_shape=bias_logical_shape) + physical_bias_operand = OperationData( + name=str(self.node.args[0]), + type=self._infer_op_data_type(bias_data), + data=bias_data, + logical_shape=bias_logical_shape, + ) # output physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) @@ -57,7 +56,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: "input": physical_input_operand, "other": physical_other_operand, "output": physical_output, - 'bias': physical_bias_operand + "bias": physical_bias_operand, } return mapping @@ -66,26 +65,27 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm')) + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="addmm") + ) return generators def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: # convert bias from its logical sharding spec to its physical sharding spec op_data_mapping = self.get_operation_data_mapping() - bias_op_data = op_data_mapping['bias'] + bias_op_data = op_data_mapping["bias"] bias_physical_shape = bias_op_data.data.shape bias_logical_shape = bias_op_data.logical_shape bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - bias_sharding_spec, bias_logical_shape, bias_physical_shape) + bias_sharding_spec, bias_logical_shape, bias_physical_shape + ) strategy.sharding_specs[bias_op_data] = bias_sharding_spec if len(removed_dims) > 0: - comm_action = comm_actions_for_oprands(node=self.node, - removed_dims=removed_dims, - op_data=bias_op_data, - sharding_spec=bias_sharding_spec) + comm_action = comm_actions_for_oprands( + node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec + ) strategy.communication_actions[bias_op_data] = comm_action return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py index cb1bb36b7879..df4b1d6cef3f 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py @@ -2,12 +2,12 @@ import torch -from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector -from .node_handler import MetaInfoModuleHandler, ModuleHandler +from ..sharding_strategy import OperationData, OperationDataType +from .node_handler import MetaInfoModuleHandler from .registry import operator_registry from .strategy import BatchNormStrategyGenerator, StrategyGenerator -__all__ = ['BatchNormModuleHandler'] +__all__ = ["BatchNormModuleHandler"] @operator_registry.register(torch.nn.BatchNorm1d) @@ -27,30 +27,37 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=self.named_parameters['weight'].shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=self.named_parameters["weight"].shape, + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) - physical_running_mean_operand = OperationData(name="running_mean", - type=OperationDataType.BUFFER, - data=self.named_buffers['running_mean'], - logical_shape=self.named_buffers['running_mean'].shape) + physical_running_mean_operand = OperationData( + name="running_mean", + type=OperationDataType.BUFFER, + data=self.named_buffers["running_mean"], + logical_shape=self.named_buffers["running_mean"].shape, + ) - physical_running_var_operand = OperationData(name="running_var", - type=OperationDataType.BUFFER, - data=self.named_buffers['running_var'], - logical_shape=self.named_buffers['running_var'].shape) + physical_running_var_operand = OperationData( + name="running_var", + type=OperationDataType.BUFFER, + data=self.named_buffers["running_var"], + logical_shape=self.named_buffers["running_var"].shape, + ) physical_num_batches_tracked_operand = OperationData( name="num_batches_tracked", type=OperationDataType.BUFFER, - data=self.named_buffers['num_batches_tracked'], - logical_shape=self.named_buffers['num_batches_tracked'].shape) + data=self.named_buffers["num_batches_tracked"], + logical_shape=self.named_buffers["num_batches_tracked"].shape, + ) mapping = { "input": physical_input_operand, @@ -58,12 +65,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: "output": physical_output, "running_mean": physical_running_mean_operand, "running_var": physical_running_var_operand, - "num_batches_tracked": physical_num_batches_tracked_operand + "num_batches_tracked": physical_num_batches_tracked_operand, } - if self.named_parameters['bias'] is not None: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + if self.named_parameters["bias"] is not None: + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py index db8f0b54ddee..f8c137348353 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py @@ -4,15 +4,14 @@ from torch.fx.node import Node from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager from ..constants import BCAST_FUNC_OP from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator -__all__ = ['BinaryElementwiseHandler'] +__all__ = ["BinaryElementwiseHandler"] @operator_registry.register(BCAST_FUNC_OP) @@ -38,7 +37,7 @@ def _get_arg_value(idx): # The meta_data of node type argument could also possibly be a non-tensor object. if not isinstance(meta_data, torch.Tensor): assert isinstance(meta_data, (int, float)) - meta_data = torch.Tensor([meta_data]).to('meta') + meta_data = torch.Tensor([meta_data]).to("meta") non_tensor = True else: @@ -46,7 +45,7 @@ def _get_arg_value(idx): # but we can deem it as meta data # as it won't affect the strategy generation assert isinstance(self.node.args[idx], (int, float)) - meta_data = torch.Tensor([self.node.args[idx]]).to('meta') + meta_data = torch.Tensor([self.node.args[idx]]).to("meta") non_tensor = True return meta_data, non_tensor @@ -58,24 +57,27 @@ def _get_arg_value(idx): # and filter the non-tensor op_data in post_process. self.non_tensor_list = [] # assert False - input_op_data = OperationData(name=str(self.node.args[0]), - type=_get_op_data_type(input_meta_data), - data=input_meta_data, - logical_shape=bcast_shape) - other_op_data = OperationData(name=str(self.node.args[1]), - type=_get_op_data_type(other_meta_data), - data=other_meta_data, - logical_shape=bcast_shape) - output_op_data = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_meta_data, - logical_shape=bcast_shape) + input_op_data = OperationData( + name=str(self.node.args[0]), + type=_get_op_data_type(input_meta_data), + data=input_meta_data, + logical_shape=bcast_shape, + ) + other_op_data = OperationData( + name=str(self.node.args[1]), + type=_get_op_data_type(other_meta_data), + data=other_meta_data, + logical_shape=bcast_shape, + ) + output_op_data = OperationData( + name=str(self.node), type=OperationDataType.OUTPUT, data=output_meta_data, logical_shape=bcast_shape + ) if non_tensor_input: self.non_tensor_list.append(input_op_data) if non_tensor_other: self.non_tensor_list.append(other_op_data) - mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data} return mapping def get_strategy_generator(self) -> List[StrategyGenerator]: @@ -100,14 +102,14 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li logical_shape = op_data.logical_shape sharding_spec = strategy.get_sharding_spec_by_name(op_data.name) sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - sharding_spec, logical_shape, physical_shape) + sharding_spec, logical_shape, physical_shape + ) strategy.sharding_specs[op_data] = sharding_spec if len(removed_dims) > 0: - comm_action = comm_actions_for_oprands(node=self.node, - removed_dims=removed_dims, - op_data=op_data, - sharding_spec=sharding_spec) + comm_action = comm_actions_for_oprands( + node=self.node, removed_dims=removed_dims, op_data=op_data, sharding_spec=sharding_spec + ) strategy.communication_actions[op_data] = comm_action return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py index da2b733c9f7a..5c22ac7bef11 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py @@ -2,15 +2,13 @@ import torch -from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager - -from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator -__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler'] +__all__ = ["BMMFunctionHandler", "AddBMMFunctionHandler"] def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): @@ -19,14 +17,14 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): node handler to reduce code redundancy. """ # input operand - physical_input_operand = OperationData(name=str(node.args[input_idx]), - type=OperationDataType.ARG, - data=node.args[input_idx]._meta_data) + physical_input_operand = OperationData( + name=str(node.args[input_idx]), type=OperationDataType.ARG, data=node.args[input_idx]._meta_data + ) # other operand - physical_other_operand = OperationData(name=str(node.args[other_idx]), - type=OperationDataType.ARG, - data=node.args[other_idx]._meta_data) + physical_other_operand = OperationData( + name=str(node.args[other_idx]), type=OperationDataType.ARG, data=node.args[other_idx]._meta_data + ) # output physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data) @@ -35,11 +33,13 @@ def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None): if bias_idx is not None: # bias physical shape bias_logical_shape = node._meta_data.shape - physical_bias_operand = OperationData(name=str(node.args[bias_idx]), - type=OperationDataType.ARG, - data=node.args[bias_idx]._meta_data, - logical_shape=bias_logical_shape) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name=str(node.args[bias_idx]), + type=OperationDataType.ARG, + data=node.args[bias_idx]._meta_data, + logical_shape=bias_logical_shape, + ) + mapping["bias"] = physical_bias_operand return mapping @@ -91,20 +91,20 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li # convert bias from its logical sharding spec to its physical sharding spec op_data_mapping = self.get_operation_data_mapping() - if 'bias' in op_data_mapping: - bias_op_data = op_data_mapping['bias'] + if "bias" in op_data_mapping: + bias_op_data = op_data_mapping["bias"] bias_physical_shape = bias_op_data.data.shape bias_logical_shape = bias_op_data.logical_shape bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name) bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - bias_sharding_spec, bias_logical_shape, bias_physical_shape) + bias_sharding_spec, bias_logical_shape, bias_physical_shape + ) strategy.sharding_specs[bias_op_data] = bias_sharding_spec if len(removed_dims) > 0: - comm_action = comm_actions_for_oprands(node=self.node, - removed_dims=removed_dims, - op_data=bias_op_data, - sharding_spec=bias_sharding_spec) + comm_action = comm_actions_for_oprands( + node=self.node, removed_dims=removed_dims, op_data=bias_op_data, sharding_spec=bias_sharding_spec + ) strategy.communication_actions[bias_op_data] = comm_action return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py index 272b1c85630a..fd7c1f837a5a 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py @@ -3,13 +3,13 @@ import torch import torch.nn.functional as F -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import transpose_partition_dim -from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler from .registry import operator_registry from .strategy import ConvStrategyGenerator, StrategyGenerator -__all__ = ['ConvModuleHandler', 'ConvFunctionHandler'] +__all__ = ["ConvModuleHandler", "ConvFunctionHandler"] @operator_registry.register(torch.nn.Conv1d) @@ -29,25 +29,29 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) logical_shape_for_weight = list(self.named_parameters["weight"].shape) - logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ - 1], logical_shape_for_weight[0] - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=torch.Size(logical_shape_for_weight)) + logical_shape_for_weight[0], logical_shape_for_weight[1] = ( + logical_shape_for_weight[1], + logical_shape_for_weight[0], + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=torch.Size(logical_shape_for_weight), + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} if "bias" in self.named_parameters: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy): @@ -77,9 +81,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) # check if the other operand is a parameter if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): @@ -88,26 +92,30 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: data_type = OperationDataType.ARG logical_shape_for_weight = list(self.node.args[1]._meta_data.shape) - logical_shape_for_weight[0], logical_shape_for_weight[1] = logical_shape_for_weight[ - 1], logical_shape_for_weight[0] - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=data_type, - data=self.node.args[1]._meta_data, - logical_shape=torch.Size(logical_shape_for_weight)) + logical_shape_for_weight[0], logical_shape_for_weight[1] = ( + logical_shape_for_weight[1], + logical_shape_for_weight[0], + ) + physical_other_operand = OperationData( + name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=torch.Size(logical_shape_for_weight), + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None: + if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None: # check if the other operand is a parameter if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): data_type = OperationDataType.PARAM else: data_type = OperationDataType.ARG - physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), - type=data_type, - data=self.node.kwargs["bias"]._meta_data) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py index 0c5b9f39e1fb..feb1032a6c0f 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py @@ -3,11 +3,11 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import DefaultReshapeGenerator, StrategyGenerator -__all__ = ['DefaultReshapeHandler'] +__all__ = ["DefaultReshapeHandler"] @operator_registry.register(torch.flatten) @@ -54,17 +54,15 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: input_data = self.node.args[0]._meta_data input_logical_shape = self.infer_logical_shape(input_data) - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=data_type, - data=input_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=data_type, data=input_data, logical_shape=input_logical_shape + ) output_data = self.node._meta_data output_logical_shape = self.infer_logical_shape(output_data) - physical_output = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_data, - logical_shape=output_logical_shape) + physical_output = OperationData( + name=str(self.node), type=OperationDataType.OUTPUT, data=output_data, logical_shape=output_logical_shape + ) mapping = {"input": physical_input_operand, "output": physical_output} diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py index 112ee194b4ec..f29c3a0b7d5d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py @@ -12,11 +12,12 @@ from .registry import operator_registry from .strategy import EmbeddingStrategyGenerator, StrategyGenerator -__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler'] +__all__ = ["EmbeddingModuleHandler", "EmbeddingFunctionHandler"] -def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str, - output_name: str) -> List[ShardingStrategy]: +def _convert_logical_sharding_to_physical_sharding_spec_for_embedding( + strategy: ShardingStrategy, input_name: str, output_name: str +) -> List[ShardingStrategy]: """ This function converts the logical sharding spec to the physical sharding spec for both the input and output of the embedding operation. @@ -56,27 +57,31 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) try: # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={0: i}, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, + dim_mapping={0: i}, + physical_shape=input_op_data.data.shape, + inplace=True, + ) if last_logical_output_dims in output_sharding_spec.dim_partition_dict: dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims} else: dim_mapping = {0: i} - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) - strategy_copy.name = f'{strategy.name}_{i}' + strategy_copy.name = f"{strategy.name}_{i}" sharding_strategies.append(strategy_copy) except ShardingNotDivisibleError as e: logger.debug( - f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}" ) else: # the generated sharding strategy does not shard the non-matrix dimension, @@ -87,20 +92,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name) # after updating, the logical shape will be replaced by the physical shape - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping={}, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, dim_mapping={}, physical_shape=input_op_data.data.shape, inplace=True + ) if last_logical_output_dims in output_sharding_spec.dim_partition_dict: dim_mapping = {last_logical_output_dims: last_physical_output_dims} else: dim_mapping = {} - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) sharding_strategies.append(strategy_copy) return sharding_strategies @@ -125,14 +131,16 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # Finally, the input will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=input_meta_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape, + ) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight']) + physical_other_operand = OperationData( + name="weight", type=OperationDataType.PARAM, data=self.named_parameters["weight"] + ) # Same as input, in nn.Embedding operation, all the dimensions of output will be treated as # (batch dimension, embedding dimension), and then the sharding spec will be generated based @@ -141,10 +149,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # Finally, the output will be transformed back to its original shape in self.post_process output_meta_data = self.node._meta_data output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape - physical_output = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_meta_data, - logical_shape=output_logical_shape) + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape, + ) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} @@ -157,10 +167,9 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li # create multiple sharding strategies for the inputs # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, - input_name=str( - self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies @@ -183,10 +192,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # Finally, the input will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape, + ) # check if the other operand is a parameter if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): @@ -194,9 +205,9 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: else: data_type = OperationDataType.ARG - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=data_type, - data=self.node.args[1]._meta_data) + physical_other_operand = OperationData( + name=str(self.node.args[1]), type=data_type, data=self.node.args[1]._meta_data + ) # Same as input, in F.embedding operation, all the dimensions of output will be treated as # (batch dimension, embedding dimension), and then the sharding spec will be generated based @@ -223,8 +234,7 @@ def post_process(self, strategy: ShardingStrategy): # create multiple sharding strategies for the inputs # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy, - input_name=str( - self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py index 53addb873d1d..dcf0a1760a2c 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py @@ -4,7 +4,7 @@ from .node_handler import NodeHandler from .strategy import GetattrGenerator, StrategyGenerator -__all__ = ['GetattrHandler'] +__all__ = ["GetattrHandler"] class GetattrHandler(NodeHandler): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py index 3466e9dd9940..bd342c12eda9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py @@ -8,7 +8,7 @@ from .registry import operator_registry from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator -__all__ = ['GetItemHandler'] +__all__ = ["GetItemHandler"] @operator_registry.register(operator.getitem) @@ -30,9 +30,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) physical_other_operand = OperationData(name="index", type=OperationDataType.ARG, data=self.node.args[1]) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py index 452381169b74..ce6b20fa1d24 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py @@ -3,11 +3,11 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoModuleHandler, ModuleHandler +from .node_handler import MetaInfoModuleHandler from .registry import operator_registry from .strategy import LayerNormGenerator, StrategyGenerator -__all__ = ['LayerNormModuleHandler'] +__all__ = ["LayerNormModuleHandler"] @operator_registry.register(torch.nn.LayerNorm) @@ -25,20 +25,22 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=self.named_parameters['weight'].shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=self.named_parameters["weight"].shape, + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if self.named_parameters['bias'] is not None: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + if self.named_parameters["bias"] is not None: + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py index ea541e434009..4177af4eaf71 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -3,24 +3,21 @@ import torch import torch.nn.functional as F -from colossalai.auto_parallel.tensor_shard.utils import ( - check_sharding_spec_validity, - transpose_partition_dim, - update_partition_dim, -) +from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim from colossalai.logging import get_dist_logger from colossalai.tensor.sharding_spec import ShardingNotDivisibleError -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector -from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy +from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler from .registry import operator_registry from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator -__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] +__all__ = ["LinearModuleHandler", "LinearFunctionHandler"] -def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy, - weight_name: str) -> ShardingStrategy: +def _update_sharding_spec_for_transposed_weight_for_linear( + strategy: ShardingStrategy, weight_name: str +) -> ShardingStrategy: """ This function is a helper function used by both module node handler and function node handler. This function will convert the sharding spec for the transposed weight to the correct partition spec. @@ -32,16 +29,17 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr # switch the dimensions of the transposed weight sharding_spec = strategy.get_sharding_spec_by_name(weight_name) op_data = strategy.get_op_data_by_name(weight_name) - assert op_data.logical_shape[0] == op_data.data.shape[1] and \ - op_data.logical_shape[1] == op_data.data.shape[0], \ - "Expected the logical shape of the linear operator's weight is equal to transposed physical shape" + assert ( + op_data.logical_shape[0] == op_data.data.shape[1] and op_data.logical_shape[1] == op_data.data.shape[0] + ), "Expected the logical shape of the linear operator's weight is equal to transposed physical shape" dim_size = len(op_data.logical_shape) transpose_partition_dim(sharding_spec, 0, dim_size - 1) return strategy -def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str, - output_name: str) -> List[ShardingStrategy]: +def _convert_logical_sharding_to_physical_sharding_spec_for_linear( + strategy: ShardingStrategy, input_name: str, output_name: str +) -> List[ShardingStrategy]: """ This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output should have the same sharding spec. @@ -99,22 +97,26 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha input_dim_mapping = {0: i} input_dim_mapping.update(input_last_dim_mapping) - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping=input_dim_mapping, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, + dim_mapping=input_dim_mapping, + physical_shape=input_op_data.data.shape, + inplace=True, + ) output_dim_mapping = {0: i} output_dim_mapping.update(output_last_dim_mapping) - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=output_dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) - strategy_copy.name = f'{strategy.name}_{i}' + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=output_dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) + strategy_copy.name = f"{strategy.name}_{i}" sharding_strategies.append(strategy_copy) except ShardingNotDivisibleError as e: logger.debug( - f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}' + f"Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}" ) else: # the generated sharding strategy does not shard the non-matrix dimension, @@ -127,17 +129,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha # after updating, the logical shape will be replaced by the physical shape input_dim_mapping = {} input_dim_mapping.update(input_last_dim_mapping) - update_partition_dim(sharding_spec=input_sharding_spec, - dim_mapping=input_dim_mapping, - physical_shape=input_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=input_sharding_spec, + dim_mapping=input_dim_mapping, + physical_shape=input_op_data.data.shape, + inplace=True, + ) output_dim_mapping = {} output_dim_mapping.update(output_last_dim_mapping) - update_partition_dim(sharding_spec=output_sharding_spec, - dim_mapping=output_dim_mapping, - physical_shape=output_op_data.data.shape, - inplace=True) + update_partition_dim( + sharding_spec=output_sharding_spec, + dim_mapping=output_dim_mapping, + physical_shape=output_op_data.data.shape, + inplace=True, + ) sharding_strategies.append(strategy_copy) return sharding_strategies @@ -152,10 +158,13 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, - self.device_mesh, - linear_projection_type='linear', - solver_perference=self.solver_perference)) + LinearProjectionStrategyGenerator( + op_data_mapping, + self.device_mesh, + linear_projection_type="linear", + solver_perference=self.solver_perference, + ) + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -163,28 +172,34 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # the strategies will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=input_meta_data, - logical_shape=input_logical_shape) - physical_other_operand = OperationData(name="weight", - type=OperationDataType.PARAM, - data=self.named_parameters['weight'], - logical_shape=self.named_parameters['weight'].shape[::-1]) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=input_meta_data, + logical_shape=input_logical_shape, + ) + physical_other_operand = OperationData( + name="weight", + type=OperationDataType.PARAM, + data=self.named_parameters["weight"], + logical_shape=self.named_parameters["weight"].shape[::-1], + ) output_meta_data = self.node._meta_data output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape - physical_output = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=output_meta_data, - logical_shape=output_logical_shape) + physical_output = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=output_meta_data, + logical_shape=output_logical_shape, + ) mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if 'bias' in self.named_parameters is not None: - physical_bias_operand = OperationData(name="bias", - type=OperationDataType.PARAM, - data=self.named_parameters['bias']) - mapping['bias'] = physical_bias_operand + if "bias" in self.named_parameters is not None: + physical_bias_operand = OperationData( + name="bias", type=OperationDataType.PARAM, data=self.named_parameters["bias"] + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: @@ -194,14 +209,14 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li 2. the input and output sharding specs are updated to physical shape. """ # switch the dimensions of the transposed weight - strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight') + strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name="weight") # create multiple sharding strategies for the inputs # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at dim 0 to one of the first few dimensions of the input - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, - input_name=str(self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies @@ -215,7 +230,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear") + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -223,10 +239,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # the strategies will be transformed back to its original shape in self.post_process input_meta_data = self.node.args[0]._meta_data input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data, - logical_shape=input_logical_shape) + physical_input_operand = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.node.args[0]._meta_data, + logical_shape=input_logical_shape, + ) # check if the other operand is a parameter if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): @@ -234,10 +252,12 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: else: data_type = OperationDataType.ARG - physical_other_operand = OperationData(name=str(self.node.args[1]), - type=data_type, - data=self.node.args[1]._meta_data, - logical_shape=self.node.args[1]._meta_data.shape[::-1]) + physical_other_operand = OperationData( + name=str(self.node.args[1]), + type=data_type, + data=self.node.args[1]._meta_data, + logical_shape=self.node.args[1]._meta_data.shape[::-1], + ) output_meta_data = self.node._meta_data output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape physical_output = OperationData( @@ -249,27 +269,28 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} - if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None: + if "bias" in self.node.kwargs and self.node.kwargs["bias"] is not None: # check if the other operand is a parameter if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter): data_type = OperationDataType.PARAM else: data_type = OperationDataType.ARG - physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]), - type=data_type, - data=self.node.kwargs["bias"]._meta_data) - mapping['bias'] = physical_bias_operand + physical_bias_operand = OperationData( + name=str(self.node.kwargs["bias"]), type=data_type, data=self.node.kwargs["bias"]._meta_data + ) + mapping["bias"] = physical_bias_operand return mapping def post_process(self, strategy: ShardingStrategy): # switch the dimensions of the transposed weight - strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, - weight_name=str(self.node.args[1])) + strategy = _update_sharding_spec_for_transposed_weight_for_linear( + strategy=strategy, weight_name=str(self.node.args[1]) + ) # create multiple sharding strategies for the inputs # as input can be multi-dimensional and the partition dim is only 2D, # we need to map the partition at dim 0 to one of the first few dimensions of the input - strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy, - input_name=str(self.node.args[0]), - output_name=str(self.node)) + strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear( + strategy=strategy, input_name=str(self.node.args[0]), output_name=str(self.node) + ) return strategies diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py index fa51114a5c94..4fab5f7f05eb 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py @@ -16,7 +16,7 @@ from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import recover_sharding_spec_for_broadcast_shape -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import ( BatchedMatMulStrategyGenerator, @@ -37,6 +37,7 @@ class MatMulType(Enum): MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D """ + DOT = 0 MM = 1 MV = 2 @@ -92,26 +93,26 @@ def __init__(self) -> None: def apply(self, shape_mapping: Dict[str, List[int]]): mapping_copy = deepcopy(shape_mapping) - input_shape = mapping_copy['input'] - other_shape = mapping_copy['other'] + input_shape = mapping_copy["input"] + other_shape = mapping_copy["other"] if len(input_shape) == 1: # if the input is a 1D tensor, 1 is prepended to its shape # and it will be removed afterwards input_shape.insert(0, 1) - self.padded_dim_mapping['input'] = -2 - self.padded_dim_mapping['output'] = -2 + self.padded_dim_mapping["input"] = -2 + self.padded_dim_mapping["output"] = -2 elif len(other_shape) == 1: # if the other is a 1D tensor, 1 is appended to its shape # and it will be removed afterwards other_shape = other_shape.append(1) - self.padded_dim_mapping['other'] = -1 - self.padded_dim_mapping['output'] = -1 + self.padded_dim_mapping["other"] = -1 + self.padded_dim_mapping["output"] = -1 return mapping_copy def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): - input_op_data = op_data_mapping['input'] - other_op_data = op_data_mapping['other'] + op_data_mapping["input"] + op_data_mapping["other"] def _remove_padded_dim(key, strategy): op_data = op_data_mapping[key] @@ -131,7 +132,7 @@ def _remove_padded_dim(key, strategy): # compute unpadded tensor shape tensor_shape.pop(padded_dim) - assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}' + assert tensor_shape == list(op_data.data.shape), f"{tensor_shape} vs {list(op_data.data.shape)}" # update sharding spec sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list) @@ -142,15 +143,15 @@ def _remove_padded_dim(key, strategy): strategy_copy = strategy.clone() # only one of input and other will be padded - if 'input' in self.padded_dim_mapping: - _remove_padded_dim('input', strategy_copy) - _remove_padded_dim('output', strategy_copy) - elif 'other' in self.padded_dim_mapping: - _remove_padded_dim('other', strategy_copy) - _remove_padded_dim('output', strategy_copy) + if "input" in self.padded_dim_mapping: + _remove_padded_dim("input", strategy_copy) + _remove_padded_dim("output", strategy_copy) + elif "other" in self.padded_dim_mapping: + _remove_padded_dim("other", strategy_copy) + _remove_padded_dim("output", strategy_copy) strategies.append(strategy_copy) - except ShardingSpecException as e: + except ShardingSpecException: pass return strategies @@ -167,8 +168,8 @@ def apply(self, shape_mapping: Dict[str, List[int]]): mapping_copy = shape_mapping.copy() # get shapes - input_shape = mapping_copy['input'] - other_shape = mapping_copy['other'] + input_shape = mapping_copy["input"] + other_shape = mapping_copy["other"] # sanity check assert len(input_shape) > 1 and len(other_shape) > 1 @@ -179,16 +180,16 @@ def apply(self, shape_mapping: Dict[str, List[int]]): # store the broadcast dim info input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2]) other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2]) - self.broadcast_dim_info['input'] = input_broadcast_dim_info - self.broadcast_dim_info['other'] = other_broadcast_dim_info + self.broadcast_dim_info["input"] = input_broadcast_dim_info + self.broadcast_dim_info["other"] = other_broadcast_dim_info # create the full logical shape input_shape = bcast_non_matrix_dims + input_shape[-2:] other_shape = bcast_non_matrix_dims + other_shape[-2:] assert len(input_shape) == len(other_shape) - mapping_copy['input'] = input_shape - mapping_copy['other'] = other_shape + mapping_copy["input"] = input_shape + mapping_copy["other"] = other_shape return mapping_copy @@ -216,17 +217,18 @@ def _remove_sharding_on_broadcast_dim(key, strategy): physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( logical_sharding_spec=sharding_spec, logical_shape=sharding_spec.entire_shape, - physical_shape=tensor_shape_before_broadcast) + physical_shape=tensor_shape_before_broadcast, + ) strategy.sharding_specs[op_data] = physical_sharding_spec # enumerate all sharding strategies strategies = [] try: strategy_copy = strategy.clone() - _remove_sharding_on_broadcast_dim('input', strategy_copy) - _remove_sharding_on_broadcast_dim('other', strategy_copy) + _remove_sharding_on_broadcast_dim("input", strategy_copy) + _remove_sharding_on_broadcast_dim("other", strategy_copy) strategies.append(strategy_copy) - except ShardingSpecException as e: + except ShardingSpecException: pass return strategies @@ -241,20 +243,20 @@ def __init__(self) -> None: def apply(self, shape_mapping: Dict[str, List[int]]): mapping_copy = shape_mapping.copy() - self.batch_dims_before_view = list(mapping_copy['input'][:-2]) + self.batch_dims_before_view = list(mapping_copy["input"][:-2]) # get shapes - input_shape = shape_mapping['input'] - other_shape = shape_mapping['other'] + input_shape = shape_mapping["input"] + other_shape = shape_mapping["other"] # view to 3d tensor assert len(input_shape) >= 3 and len(other_shape) >= 3 input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:] other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:] output_shape = input_shape[:2] + other_shape[2:] - mapping_copy['input'] = input_shape - mapping_copy['other'] = other_shape - mapping_copy['output'] = output_shape + mapping_copy["input"] = input_shape + mapping_copy["other"] = other_shape + mapping_copy["output"] = output_shape return mapping_copy def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy): @@ -291,11 +293,11 @@ def _update_sharding_spec(key, strategy, physical_batch_dim): # create a new strategy strategy_copy = strategy.clone() try: - _update_sharding_spec('input', strategy_copy, i) - _update_sharding_spec('other', strategy_copy, i) - _update_sharding_spec('output', strategy_copy, i) + _update_sharding_spec("input", strategy_copy, i) + _update_sharding_spec("other", strategy_copy, i) + _update_sharding_spec("output", strategy_copy, i) strategies.append(strategy_copy) - except ShardingSpecException as e: + except ShardingSpecException: continue return strategies @@ -312,14 +314,14 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms): 3. reshape to 3 dimensions """ - shape_mapping = {'input': input_shape, 'other': other_shape} + shape_mapping = {"input": input_shape, "other": other_shape} for transform in transforms: shape_mapping = transform.apply(shape_mapping) - input_shape = shape_mapping.get('input', None) - other_shape = shape_mapping.get('other', None) - output_shape = shape_mapping.get('output', None) + input_shape = shape_mapping.get("input", None) + other_shape = shape_mapping.get("other", None) + output_shape = shape_mapping.get("output", None) return input_shape, other_shape, output_shape @@ -364,7 +366,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh)) elif self.matmul_type == MatMulType.MM: generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type="linear") + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -372,7 +375,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: MatMulType.DOT: self._get_logical_shape_for_dot, MatMulType.MM: self._get_logical_shape_for_mm, MatMulType.MV: self._get_logical_shape_for_mv, - MatMulType.BMM: self._get_logical_shape_for_bmm + MatMulType.BMM: self._get_logical_shape_for_bmm, } logical_shapes = logical_shape_func[self.matmul_type]() op_data_mapping = self._get_op_data_mapping(*logical_shapes) @@ -390,20 +393,26 @@ def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_ output_logical_shape = torch.Size(output_logical_shape) # create op data - input_op_data = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.input_meta_data, - logical_shape=input_logical_shape) - other_op_data = OperationData(name=str(self.node.args[1]), - type=OperationDataType.ARG, - data=self.other_meta_data, - logical_shape=other_logical_shape) - output_op_data = OperationData(name=str(self.node), - type=OperationDataType.OUTPUT, - data=self.output_meta_data, - logical_shape=output_logical_shape) - - mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data} + input_op_data = OperationData( + name=str(self.node.args[0]), + type=OperationDataType.ARG, + data=self.input_meta_data, + logical_shape=input_logical_shape, + ) + other_op_data = OperationData( + name=str(self.node.args[1]), + type=OperationDataType.ARG, + data=self.other_meta_data, + logical_shape=other_logical_shape, + ) + output_op_data = OperationData( + name=str(self.node), + type=OperationDataType.OUTPUT, + data=self.output_meta_data, + logical_shape=output_logical_shape, + ) + + mapping = {"input": input_op_data, "other": other_op_data, "output": output_op_data} return mapping def _get_logical_shape_for_dot(self): @@ -460,9 +469,11 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li dim_partition_dict[0] = shard # re-init the sharding spec - input_sharding_spec.__init__(input_sharding_spec.device_mesh, - entire_shape=input_physical_shape, - dim_partition_dict=dim_partition_dict) + input_sharding_spec.__init__( + input_sharding_spec.device_mesh, + entire_shape=input_physical_shape, + dim_partition_dict=dim_partition_dict, + ) return strategy else: return strategy @@ -481,7 +492,8 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li recovered_stragies.extend(output) else: raise TypeError( - f"Found unexpected output type {type(output)} from the recover method of BmmTransform") + f"Found unexpected output type {type(output)} from the recover method of BmmTransform" + ) strategies = recovered_stragies for index, strategies in enumerate(strategies): strategies.name = f"{strategies.name}_{index}" diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index b4b7b0e794d1..d2bad39dcbb9 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -8,7 +8,6 @@ from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, - OperationDataType, ShardingSpec, ShardingStrategy, StrategiesVector, @@ -23,21 +22,23 @@ class NodeHandler(ABC): - ''' + """ The NodeHandler is an abstract class used to generate every possible strategies for an operator node. Args: node (Node): the input node in node argument list. device_mesh (DeviceMesh): A logical view of a physical mesh. strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. - ''' - - def __init__(self, - node: Node, - device_mesh: DeviceMesh, - strategies_vector: StrategiesVector, - shard_option: ShardOption = ShardOption.STANDARD, - solver_perference: SolverPerference = SolverPerference.STANDARD) -> None: + """ + + def __init__( + self, + node: Node, + device_mesh: DeviceMesh, + strategies_vector: StrategiesVector, + shard_option: ShardOption = ShardOption.STANDARD, + solver_perference: SolverPerference = SolverPerference.STANDARD, + ) -> None: self.node = node self.predecessor_node = list(node._input_nodes.keys()) self.successor_node = list(node.users.keys()) @@ -68,8 +69,9 @@ def update_resharding_cost(self, strategy: ShardingStrategy) -> None: current_sharding_spec = strategy.sharding_specs[op_data] # get the sharding specs for this node generated # in its own node handler - assert hasattr(node, 'strategies_vector'), \ - f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.' + assert hasattr( + node, "strategies_vector" + ), f"The predecessor node {node_name} has no strategy vector to compute the resharding cost." prev_strategy_vector = node.strategies_vector prev_sharding_specs = [ prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector @@ -80,10 +82,10 @@ def update_resharding_cost(self, strategy: ShardingStrategy) -> None: resharding_costs[node] = [] def _compute_resharding_cost( - prev_sharding_spec: Union[ShardingSpec, - List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec, - List[ShardingSpec]], - data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem: + prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]], + current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]], + data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], + ) -> TrainCycleItem: """ This is a helper function to compute the resharding cost for a specific strategy of a node. """ @@ -94,30 +96,35 @@ def _compute_resharding_cost( dtype = data.dtype size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() _, _, consistency_cost = shape_consistency_manager.shape_consistency( - prev_sharding_spec, current_sharding_spec) - - resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes, - bwd=consistency_cost["backward"] * size_per_elem_bytes, - total=consistency_cost["total"] * size_per_elem_bytes) + prev_sharding_spec, current_sharding_spec + ) + + resharding_cost = TrainCycleItem( + fwd=consistency_cost["forward"] * size_per_elem_bytes, + bwd=consistency_cost["backward"] * size_per_elem_bytes, + total=consistency_cost["total"] * size_per_elem_bytes, + ) return resharding_cost else: # This raise is used to check if we have missed any type of data. # It could be merged into Parameter branch, which means we won't handle # non-tensor arguments. - raise ValueError(f'Unsupported data type {type(data)}') + raise ValueError(f"Unsupported data type {type(data)}") else: - assert isinstance(prev_sharding_spec, (tuple, list)), \ - f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \ - or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}' + assert isinstance( + prev_sharding_spec, (tuple, list) + ), f"prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \ + or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}" fwd_cost = 0 bwd_cost = 0 total_cost = 0 - for index, (prev_sharding_spec_item, - current_sharding_spec_item) in enumerate(zip(prev_sharding_spec, - current_sharding_spec)): - item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item, - data[index]) + for index, (prev_sharding_spec_item, current_sharding_spec_item) in enumerate( + zip(prev_sharding_spec, current_sharding_spec) + ): + item_cost = _compute_resharding_cost( + prev_sharding_spec_item, current_sharding_spec_item, data[index] + ) fwd_cost += item_cost.fwd bwd_cost += item_cost.bwd total_cost += item_cost.total @@ -138,17 +145,17 @@ def get_target_function(self) -> callable: This function is used to get the target function for the node handler. The target function is used to analyze the costs of strategies. """ - if self.node.op in ('placeholder', 'get_attr', 'output'): + if self.node.op in ("placeholder", "get_attr", "output"): return None - if self.node.op == 'call_module': + if self.node.op == "call_module": target = self.node.graph.owning_module.get_submodule(self.node.target) - elif self.node.op == 'call_function': + elif self.node.op == "call_function": target = self.node.target - elif self.node.op == 'call_method': + elif self.node.op == "call_method": target = getattr(self.node.args[0]._meta_data.__class__, self.node.target) else: - raise ValueError(f'Unsupported node type: {self.node.op}') + raise ValueError(f"Unsupported node type: {self.node.op}") return target @@ -221,7 +228,6 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: """ Define which generators should be used by this NodeHandler object. """ - pass @abstractmethod def get_operation_data_mapping(self) -> Dict[str, OperationData]: @@ -244,7 +250,6 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: "output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data), } """ - pass class MetaInfoNodeHandler(NodeHandler): @@ -278,19 +283,19 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV else: logger = get_dist_logger() - logger.warning(f'The target function {target} is not patched yet, ') + logger.warning(f"The target function {target} is not patched yet, ") return self.strategies_vector class ModuleHandler(NodeHandler): - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # set attributes to access module parameters for convenience - assert self.node.graph.owning_module is not None, \ - f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.' + assert ( + self.node.graph.owning_module is not None + ), f"The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object." module = self.node.graph.owning_module.get_submodule(self.node.target) named_parameters = list(module.named_parameters(recurse=False)) named_buffers = list(module.named_buffers(recurse=False)) @@ -333,6 +338,6 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV else: logger = get_dist_logger() - logger.warning(f'The target function {target} is not patched yet') + logger.warning(f"The target function {target} is not patched yet") return self.strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py index 4e71ccba95a7..facf19560596 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py @@ -3,11 +3,11 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoModuleHandler, ModuleHandler +from .node_handler import MetaInfoModuleHandler from .registry import operator_registry from .strategy import NormalPoolStrategyGenerator, StrategyGenerator -__all__ = ['NormPoolingHandler'] +__all__ = ["NormPoolingHandler"] @operator_registry.register(torch.nn.MaxPool1d) @@ -30,9 +30,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) physical_weight_operand = OperationData(name="kernel", type=OperationDataType.ARG, data=self.module.kernel_size) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py index ed120a8c3d6d..89906a205e87 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py @@ -8,7 +8,7 @@ from .node_handler import NodeHandler from .strategy import OutputGenerator, StrategyGenerator -__all__ = ['OutputHandler'] +__all__ = ["OutputHandler"] class OutputHandler(NodeHandler): @@ -16,8 +16,9 @@ class OutputHandler(NodeHandler): A OutputHandler which deals with the sharding strategies for Output Node. """ - def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, - output_option: str) -> None: + def __init__( + self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, output_option: str + ) -> None: super().__init__(node, device_mesh, strategies_vector) self.output_option = output_option @@ -35,11 +36,11 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: for index, input_node in enumerate(self.predecessor_node): input_meta_data = input_node._meta_data physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data) - name_key = f'input_{index}' + name_key = f"input_{index}" mapping[name_key] = physical_inputs output_meta_data.append(input_meta_data) - assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.' + assert len(output_meta_data) > 0, f"Output node {self.node} has no input node." if len(output_meta_data) == 1: output_meta_data = output_meta_data[0] else: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py index 91e4a5105a08..75f07168e47b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import PermuteGenerator, StrategyGenerator -__all__ = ['PermuteHandler'] +__all__ = ["PermuteHandler"] @operator_registry.register(torch.Tensor.permute) @@ -34,14 +34,14 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) permute_dims = [] - if self.node.op == 'call_method': + if self.node.op == "call_method": # torch.Tensor.permute (input, *dims) for arg in self.node.args: if isinstance(arg, torch.fx.Node): if isinstance(arg._meta_data, int): permute_dims.append(arg._meta_data) else: - assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.' + assert isinstance(arg, int), "The argument in permute node should be either type of Node or int." permute_dims.append(arg) else: # torch.permute (input, dims) @@ -51,8 +51,8 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: permute_dims.extend(arg._meta_data) else: assert isinstance( - arg, - (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].' + arg, (tuple, list) + ), "The argument in permute node should be type of Node, Tuple[int] or List[int]." permute_dims.extend(arg) num_dims = self.node._meta_data.dim() @@ -61,7 +61,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: if permute_dims[i] < 0: permute_dims[i] += num_dims - physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims)) + physical_shape_operand = OperationData(name="permute_dims", type=OperationDataType.ARG, data=list(permute_dims)) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -69,7 +69,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "permute_dims": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py index e4f40fc935a4..461bc2935780 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py @@ -8,7 +8,7 @@ from .node_handler import NodeHandler from .strategy import PlaceholderGenerator, StrategyGenerator -__all__ = ['PlaceholderHandler'] +__all__ = ["PlaceholderHandler"] class PlaceholderHandler(NodeHandler): @@ -16,8 +16,9 @@ class PlaceholderHandler(NodeHandler): A PlaceholderHandler which deals with the sharding strategies for Placeholder Node. """ - def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, - placeholder_option: str) -> None: + def __init__( + self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector, placeholder_option: str + ) -> None: super().__init__(node, device_mesh, strategies_vector) self.placeholder_option = placeholder_option @@ -25,7 +26,8 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option)) + PlaceholderGenerator(op_data_mapping, self.device_mesh, placeholder_option=self.placeholder_option) + ) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py index 730a90d74cf8..f663fc9695d3 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/registry.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/registry.py @@ -1,11 +1,9 @@ class Registry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): if isinstance(source, (list, tuple)): # support register a list of items for this func @@ -18,7 +16,7 @@ def wrapper(func): return wrapper def get(self, source): - assert source in self.store, f'{source} not found in the {self.name} registry' + assert source in self.store, f"{source} not found in the {self.name} registry" target = self.store[source] return target @@ -26,4 +24,4 @@ def has(self, source): return source in self.store -operator_registry = Registry('operator') +operator_registry = Registry("operator") diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py index 743a1f90eaaf..6e883ea64736 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import SoftmaxGenerator, StrategyGenerator -__all__ = ['SoftmaxHandler'] +__all__ = ["SoftmaxHandler"] @operator_registry.register(torch.nn.Softmax) @@ -34,14 +34,14 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: input_data = self.node.args[0]._meta_data physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) - softmax_dim = self.node.kwargs['dim'] + softmax_dim = self.node.kwargs["dim"] num_dims = self.node.args[0]._meta_data.dim() # recover negative value to positive if softmax_dim < 0: softmax_dim += num_dims - physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) + physical_dim_operand = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -49,7 +49,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "softmax_dim": physical_dim_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py index 653d158b7c36..4c32529a5d5b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import SplitGenerator, StrategyGenerator -__all__ = ['SplitHandler'] +__all__ = ["SplitHandler"] @operator_registry.register(torch.Tensor.split) @@ -38,7 +38,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: split_dim = self.node.args[2] else: if self.node.kwargs: - split_dim = self.node.kwargs['dim'] + split_dim = self.node.kwargs["dim"] else: split_dim = 0 @@ -48,7 +48,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: split_dim += num_dims split_info = (split_size, split_dim) - physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info) + physical_shape_operand = OperationData(name="split_info", type=OperationDataType.ARG, data=split_info) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -56,7 +56,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "split_info": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py index db1f31521c86..1fc7f613716b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py @@ -29,11 +29,31 @@ from .where_generator import WhereGenerator __all__ = [ - 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator', - 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', - 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', - 'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator', - 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator', - 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator', - 'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator' + "StrategyGenerator", + "DotProductStrategyGenerator", + "MatVecStrategyGenerator", + "LinearProjectionStrategyGenerator", + "BatchedMatMulStrategyGenerator", + "ConvStrategyGenerator", + "UnaryElementwiseGenerator", + "BatchNormStrategyGenerator", + "GetItemStrategyGenerator", + "TensorStrategyGenerator", + "TensorTupleStrategyGenerator", + "LayerNormGenerator", + "PlaceholderGenerator", + "OutputGenerator", + "WhereGenerator", + "NormalPoolStrategyGenerator", + "BinaryElementwiseStrategyGenerator", + "GetattrGenerator", + "TensorConstructorGenerator", + "EmbeddingStrategyGenerator", + "SumGenerator", + "SoftmaxGenerator", + "ViewGenerator", + "PermuteGenerator", + "TransposeGenerator", + "SplitGenerator", + "DefaultReshapeGenerator", ] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py index 416dc9c29cad..9c766b1014c8 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -14,7 +14,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['BatchNormStrategyGenerator'] +__all__ = ["BatchNormStrategyGenerator"] class BatchNormStrategyGenerator(StrategyGenerator): @@ -30,28 +30,31 @@ class BatchNormStrategyGenerator(StrategyGenerator): """ def validate(self) -> bool: - ''' + """ In sanity check, we need make sure the input data having correct dimension size. For BatchNorm1d, the dim of input data should be 3([N, C, L]). For BatchNorm2d, the dim of input data should be 4([N, C, H, W]). For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - input_op_data = self.op_data['input'] + """ + input_op_data = self.op_data["input"] assert input_op_data.data.dim() in ( - 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + 3, + 4, + 5, + ), f"We suppose the dim of input fed into conv op should in range of [3, 5]." def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. - ''' + """ # TODO: a constant coefficient need to be added. # 1D: (L) * N * Cin # 2D: (H * W) * N * Cin # 3D: (H * W * D) * N * Cin - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() if self.has_bias: # bias add is an element wise operation, so the cost is equal to product of output shape. bias_compute_cost = reduce(operator.mul, sharded_output_shape) @@ -69,23 +72,24 @@ def update_compute_cost(self, strategy: ShardingStrategy): def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output"), - 'running_mean': self._compute_size_in_bytes(strategy, "running_mean"), - 'running_var': self._compute_size_in_bytes(strategy, "running_var"), + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), + "running_mean": self._compute_size_in_bytes(strategy, "running_mean"), + "running_var": self._compute_size_in_bytes(strategy, "running_var"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - forward_size_mapping['bias'] = bias_size + forward_size_mapping["bias"] = bias_size backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping.pop("output") # compute fwd cost incurred # fwd_cost = input + other + bias + output fwd_activation_cost = sum( - [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]) + [v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)] + ) fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)]) fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost) @@ -93,36 +97,29 @@ def update_memory_cost(self, strategy: ShardingStrategy): # compute bwd cost incurred # bwd_cost = input_grad + other_grad + bias_grad bwd_activation_cost = sum( - [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)]) + [v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)] + ) bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost, - buffer=fwd_buffer_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost, + buffer=fwd_buffer_cost, + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @ignore_sharding_exception def split_input_channel(self, mesh_dim_0): - name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' + name = f"RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}" dim_partition_dict_mapping = { - "input": { - 1: [mesh_dim_0] - }, - "other": { - 0: [mesh_dim_0] - }, - "output": { - 1: [mesh_dim_0] - }, - "running_mean": { - 0: [mesh_dim_0] - }, - "running_var": { - 0: [mesh_dim_0] - }, + "input": {1: [mesh_dim_0]}, + "other": {0: [mesh_dim_0]}, + "output": {1: [mesh_dim_0]}, + "running_mean": {0: [mesh_dim_0]}, + "running_var": {0: [mesh_dim_0]}, "num_batches_tracked": {}, } if self.has_bias: @@ -132,29 +129,21 @@ def split_input_channel(self, mesh_dim_0): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}" dim_partition_dict_mapping = { - "input": { - 1: [mesh_dim_0, mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "output": { - 1: [mesh_dim_0, mesh_dim_1] - }, - "running_mean": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "running_var": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {1: [mesh_dim_0, mesh_dim_1]}, + "other": {0: [mesh_dim_0, mesh_dim_1]}, + "output": {1: [mesh_dim_0, mesh_dim_1]}, + "running_mean": {0: [mesh_dim_0, mesh_dim_1]}, + "running_var": {0: [mesh_dim_0, mesh_dim_1]}, "num_batches_tracked": {}, } if self.has_bias: @@ -164,13 +153,15 @@ def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def non_split(self): - name = f'RR = RR x R' + name = f"RR = RR x R" dim_partition_dict_mapping = { "input": {}, "other": {}, @@ -186,21 +177,19 @@ def non_split(self): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_batch(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, + "input": {0: [mesh_dim_0]}, "other": {}, - "output": { - 0: [mesh_dim_0] - }, + "output": {0: [mesh_dim_0]}, "running_mean": {}, "running_var": {}, "num_batches_tracked": {}, @@ -218,27 +207,26 @@ def split_input_batch(self, mesh_dim_0): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.IMPLICIT) + comm_type=CommType.IMPLICIT, + ) # TODO: Temporary solution has no communication cost, # above action should be added after the SyncBN replace pass completed. communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, "other": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "output": {0: [mesh_dim_0, mesh_dim_1]}, "running_mean": {}, "running_var": {}, "num_batches_tracked": {}, @@ -256,19 +244,22 @@ def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.IMPLICIT) + comm_type=CommType.IMPLICIT, + ) # TODO: Temporary solution has no communication cost, # above action should be added after the SyncBN replace pass completed. communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN" dim_partition_dict_mapping = { "input": { 0: [mesh_dim_0], @@ -304,20 +295,23 @@ def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0], - comm_type=CommType.IMPLICIT) + comm_type=CommType.IMPLICIT, + ) # TODO: Temporary solution has no communication cost, # above action should be added after the SyncBN replace pass completed. communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: - ''' + """ Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. - ''' + """ strategy_list = [] # RS = RS x S diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py index d27cc046eaf3..c7da0034ec3b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py @@ -14,7 +14,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['BinaryElementwiseStrategyGenerator'] +__all__ = ["BinaryElementwiseStrategyGenerator"] class BinaryElementwiseStrategyGenerator(StrategyGenerator): @@ -26,36 +26,37 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator): """ def validate(self) -> bool: - assert len(self.op_data) == 3, \ - f'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}' + assert ( + len(self.op_data) == 3 + ), f"BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got {len(self.op_data)}" for name, op_data in self.op_data.items(): if not isinstance(op_data.data, (torch.Tensor, int, float)): - raise TypeError(f'The operation data {name} is not a torch.Tensor/int/float.') + raise TypeError(f"The operation data {name} is not a torch.Tensor/int/float.") def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() # since elementwise ops are not compute-intensive, # we approximate the backward compute cost # to be twice the fwd compute cost fwd_compute_cost = reduce(operator.mul, shape) bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # all input, output and outputs have the same shape - shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() # compute fwd memory cost in bytes # as the elementwise ops are not memory-intensive # we approximate the fwd memory cost to be the output # and the backward memory cost to be grad of input and other - input_bytes = self._compute_size_in_bytes(strategy, 'input') - other_bytes = self._compute_size_in_bytes(strategy, 'other') - output_bytes = self._compute_size_in_bytes(strategy, 'output') + input_bytes = self._compute_size_in_bytes(strategy, "input") + other_bytes = self._compute_size_in_bytes(strategy, "other") + output_bytes = self._compute_size_in_bytes(strategy, "output") fwd_memory_cost = MemoryCost(activation=output_bytes) bwd_memory_cost = MemoryCost(activation=input_bytes + other_bytes) total_memory_cost = MemoryCost(activation=input_bytes + other_bytes + output_bytes) @@ -66,7 +67,7 @@ def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): # we check for the output logical shape to get the number of dimensions dim_partition_list = [] - dim_size = len(self.op_data['output'].logical_shape) + dim_size = len(self.op_data["output"].logical_shape) # enumerate all the 2D sharding cases sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) @@ -86,21 +87,22 @@ def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): # convert these dim partition dict to sharding strategy for dim_partition_dict in dim_partition_list: - dim_partition_dict_mapping = dict(input=dim_partition_dict, - other=dim_partition_dict, - output=dim_partition_dict) + dim_partition_dict_mapping = dict( + input=dim_partition_dict, other=dim_partition_dict, output=dim_partition_dict + ) try: sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) communication_action_mapping = {} # get name - sharding_seq = sharding_spec_mapping['input'].sharding_sequence - name = f'{sharding_seq} = {sharding_seq} {sharding_seq}' + sharding_seq = sharding_spec_mapping["input"].sharding_sequence + name = f"{sharding_seq} = {sharding_seq} {sharding_seq}" sharding_strategy = self.get_sharding_strategy( name=name, sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(sharding_strategy) except ShardingSpecException: continue diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py index e605a68a326b..5208f61543bb 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py @@ -1,11 +1,9 @@ import copy import operator -import warnings from functools import reduce from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, CommType, MemoryCost, ShardingStrategy, @@ -24,29 +22,32 @@ class ConvStrategyGenerator(StrategyGenerator): """ def validate(self) -> bool: - ''' + """ In sanity check, we need make sure the input data having correct dimension size. For Conv1d, the dim of input data should be 3([N, C, L]). For Conv2d, the dim of input data should be 4([N, C, H, W]). For Conv3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - input_op_data = self.op_data['input'] + """ + input_op_data = self.op_data["input"] assert input_op_data.data.dim() in ( - 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + 3, + 4, + 5, + ), f"We suppose the dim of input fed into conv op should in range of [3, 5]." def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. - ''' + """ # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size. # 1D: (L) * N * Cout * Cin * kernel # 2D: (H * W) * N * Cout * Cin * kernel # 3D: (H * W * D) * N * Cout * Cin * kernel - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() if self.has_bias: # bias add is an element wise operation, so the cost is equal to product of output shape. bias_compute_cost = reduce(operator.mul, sharded_output_shape) @@ -76,14 +77,14 @@ def update_compute_cost(self, strategy: ShardingStrategy): def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - forward_size_mapping['bias'] = bias_size + forward_size_mapping["bias"] = bias_size backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping.pop("output") @@ -100,26 +101,20 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @ignore_sharding_exception def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - 1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0]}, + "other": {1: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], 1: [mesh_dim_1]}, } if self.has_bias: dim_partition_dict_mapping["bias"] = {0: [mesh_dim_1]} @@ -132,7 +127,8 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} if self.is_param("other"): @@ -140,7 +136,8 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -148,38 +145,41 @@ def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: - if self.is_param('bias'): + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_batch(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}R x RR" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, + "input": {0: [mesh_dim_0]}, "other": {}, "output": { 0: [mesh_dim_0], @@ -196,7 +196,8 @@ def split_input_batch(self, mesh_dim_0): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -204,42 +205,45 @@ def split_input_batch(self, mesh_dim_0): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: - if self.is_param('bias'): + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R" dim_partition_dict_mapping = { "input": { 0: [mesh_dim_0], 1: [mesh_dim_1], }, - "other": { - 0: [mesh_dim_1] - }, + "other": {0: [mesh_dim_1]}, "output": { 0: [mesh_dim_0], }, @@ -254,7 +258,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} @@ -263,7 +268,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -271,7 +277,8 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: if self.is_param("bias"): @@ -279,23 +286,27 @@ def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}" dim_partition_dict_mapping = { "input": { @@ -322,23 +333,27 @@ def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) input_comm_action = self.get_communication_action( sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"output": output_comm_action, "input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_in_channel_weight_in_channel(self, mesh_dim_0): - name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' + name = f"RR = RS{mesh_dim_0} x S{mesh_dim_0}R" dim_partition_dict_mapping = { "input": { @@ -360,17 +375,20 @@ def split_input_in_channel_weight_in_channel(self, mesh_dim_0): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_weight_out_channel(self, mesh_dim_0): - name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' + name = f"RS{mesh_dim_0} = RR x RS{mesh_dim_0}" dim_partition_dict_mapping = { "input": {}, @@ -395,17 +413,20 @@ def split_weight_out_channel(self, mesh_dim_0): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def non_split(self): - name = f'RR = RR x RR' + name = f"RR = RR x RR" dim_partition_dict_mapping = { "input": {}, @@ -418,13 +439,13 @@ def non_split(self): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) + return self.get_sharding_strategy( + name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={} + ) @ignore_sharding_exception def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR" dim_partition_dict_mapping = { "input": { @@ -447,14 +468,16 @@ def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action @@ -464,23 +487,27 @@ def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - key_for_kwarg='bias') + key_for_kwarg="bias", + ) communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R" dim_partition_dict_mapping = { "input": { 1: [mesh_dim_0, mesh_dim_1], @@ -501,17 +528,20 @@ def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}" dim_partition_dict_mapping = { "input": {}, "other": { @@ -535,13 +565,16 @@ def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategies = [] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py index 82a04ab52e73..385a8886f231 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py @@ -1,11 +1,9 @@ import copy import operator -import warnings from functools import reduce from typing import List from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, CommType, MemoryCost, ShardingStrategy, @@ -27,16 +25,16 @@ def validate(self) -> bool: return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. Note: The computation cost for the embedding handler is estimated as dense computing now. It may not be accurate. - ''' + """ # TODO: estimate the embedding computation cost as sparse operation - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() input_size_product = reduce(operator.mul, sharded_input_shape) other_size_product = reduce(operator.mul, sharded_other_shape) @@ -55,9 +53,9 @@ def update_compute_cost(self, strategy: ShardingStrategy): def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -75,14 +73,15 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @ignore_sharding_exception def non_split(self): - name = f'RR = R x RR' + name = f"RR = R x RR" dim_partition_dict_mapping = { "input": {}, @@ -92,18 +91,16 @@ def non_split(self): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) + return self.get_sharding_strategy( + name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={} + ) @ignore_sharding_exception def split_input(self, mesh_dim_0): - name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR' + name = f"S{mesh_dim_0}R = S{mesh_dim_0} x RR" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, + "input": {0: [mesh_dim_0]}, "other": {}, "output": { 0: [mesh_dim_0], @@ -118,7 +115,8 @@ def split_input(self, mesh_dim_0): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -126,17 +124,20 @@ def split_input(self, mesh_dim_0): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}" dim_partition_dict_mapping = { "input": { @@ -159,7 +160,8 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} if self.is_param("other"): @@ -167,7 +169,8 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -175,22 +178,23 @@ def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, "other": {}, "output": { 0: [mesh_dim_0, mesh_dim_1], @@ -207,7 +211,8 @@ def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1): sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( @@ -215,17 +220,20 @@ def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) communication_action_mapping["other"] = other_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_embedding_dim(self, mesh_dim_0): - name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}' + name = f"RS{mesh_dim_0} = R x RS{mesh_dim_0}" dim_partition_dict_mapping = { "input": {}, @@ -245,17 +253,20 @@ def split_embedding_dim(self, mesh_dim_0): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}" dim_partition_dict_mapping = { "input": {}, @@ -275,13 +286,16 @@ def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping = {"input": input_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategies = [] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py index bbeb9a639c83..cc8d5771f28e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py @@ -10,7 +10,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['GetattrGenerator'] +__all__ = ["GetattrGenerator"] class GetattrGenerator(StrategyGenerator): @@ -26,10 +26,10 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' - forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + """ + forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")} # compute fwd cost incurred # fwd_cost = output @@ -47,7 +47,7 @@ def update_memory_cost(self, strategy: ShardingStrategy): def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): # we check for the output logical shape to get the number of dimensions dim_partition_list = [] - dim_size = len(self.op_data['output'].logical_shape) + dim_size = len(self.op_data["output"].logical_shape) # enumerate all the 2D sharding cases sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size) @@ -78,7 +78,8 @@ def enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1): sharding_strategy = self.get_sharding_strategy( name=name, sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(sharding_strategy) except ShardingSpecException: continue diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py index 0aeb2e0d4079..6f01d9cc7f8e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py @@ -1,19 +1,13 @@ import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem from colossalai.logging import get_dist_logger -from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.sharding_spec import ShardingSpecException from .strategy_generator import FollowingStrategyGenerator -__all__ = ['GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator'] +__all__ = ["GetItemStrategyGenerator", "TensorStrategyGenerator", "TensorTupleStrategyGenerator"] class GetItemStrategyGenerator(FollowingStrategyGenerator): @@ -35,12 +29,12 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -58,27 +52,29 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost class TensorStrategyGenerator(GetItemStrategyGenerator): - ''' + """ Deal with case 1 and 2. - ''' + """ def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] - getitem_index = self.op_data['index'].data + getitem_index = self.op_data["index"].data for index, strategy in enumerate(self.predecessor_node.strategies_vector): try: logger = get_dist_logger() dim_partition_dict_mapping = {} communication_action_mapping = {} dim_partition_dict_for_input = copy.deepcopy( - strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict) + strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict + ) int_index = False if isinstance(getitem_index, int): @@ -120,9 +116,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) except ShardingSpecException as e: logger.debug(e) continue @@ -137,9 +135,9 @@ def collate_strategies(self) -> List[ShardingStrategy]: class TensorTupleStrategyGenerator(GetItemStrategyGenerator): - ''' + """ Deal with case 3. - ''' + """ def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] @@ -158,13 +156,15 @@ def collate_strategies(self) -> List[ShardingStrategy]: sharding_spec_mapping["input"] = sharding_spec_for_input input_sharding_info = f"get the {index} element from (" for sharding_spec in sharding_spec_for_input: - input_sharding_info += f'{sharding_spec.sharding_sequence}, ' + input_sharding_info += f"{sharding_spec.sharding_sequence}, " input_sharding_info += ")" name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py index 65b173bbf65d..e5b7e6f25d4d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py @@ -18,7 +18,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['LayerNormGenerator'] +__all__ = ["LayerNormGenerator"] class LayerNormGenerator(StrategyGenerator): @@ -31,21 +31,21 @@ def validate(self) -> bool: return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. - ''' + """ # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size. # TODO: a constant coefficient need to be added. - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_weight_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_weight_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() if self.has_bias: # bias add is an element wise operation, so the cost is equal to product of output shape. bias_compute_cost = reduce(operator.mul, sharded_weight_shape) # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization. - input_batch_shape = sharded_input_shape[:-len(sharded_weight_shape)] + input_batch_shape = sharded_input_shape[: -len(sharded_weight_shape)] input_batch_product = reduce(operator.mul, input_batch_shape, 1) norm_kernel_product = reduce(operator.mul, sharded_weight_shape, 1) forward_compute_cost = input_batch_product * norm_kernel_product @@ -62,18 +62,18 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - forward_size_mapping['bias'] = bias_size + forward_size_mapping["bias"] = bias_size backward_size_mapping = copy.deepcopy(forward_size_mapping) backward_size_mapping.pop("output") @@ -90,8 +90,9 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -120,7 +121,8 @@ def _generate_strategy_with_dim_partition(self, dim_partition): sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=total_mesh_dim_list, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) communication_action_mapping["other"] = other_comm_action if self.has_bias: @@ -128,12 +130,15 @@ def _generate_strategy_with_dim_partition(self, dim_partition): sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=total_mesh_dim_list, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) communication_action_mapping["bias"] = bias_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy @@ -155,7 +160,7 @@ def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1, batch_dimensio @ignore_sharding_exception def non_split(self): - name = f'RR = RR x R' + name = f"RR = RR x R" dim_partition_dict_mapping = { "input": {}, "other": {}, @@ -168,14 +173,16 @@ def non_split(self): communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: - ''' + """ Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector. - ''' + """ strategy_list = [] input_data_dim = len(self.op_data["input"].logical_shape) weight_data_dim = len(self.op_data["other"].logical_shape) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index aa1581b99e0f..fb182afb9175 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -1,5 +1,4 @@ import operator -from ast import arg from functools import reduce from typing import List @@ -24,14 +23,14 @@ class MatMulStrategyGenerator(StrategyGenerator): def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'other': self._compute_size_in_bytes(strategy, "other"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "other": self._compute_size_in_bytes(strategy, "other"), + "output": self._compute_size_in_bytes(strategy, "output"), } if self.has_bias: bias_size = self._compute_size_in_bytes(strategy, "bias") - size_mapping['bias'] = bias_size + size_mapping["bias"] = bias_size # compute fwd cost incurred # fwd_cost = input + other + bias + output @@ -41,45 +40,47 @@ def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # compute bwd cost incurred # bwd_cost = input_grad + bias_grad - bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']]) + bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ["input", "other", "bias"]]) bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + 0) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + 0 + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost class DotProductStrategyGenerator(MatMulStrategyGenerator): - def validate(self) -> bool: - input_op_data = self.op_data['input'] - other_op_data = self.op_data['other'] + input_op_data = self.op_data["input"] + other_op_data = self.op_data["other"] assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1 def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() fwd_compute_cost = sharded_input_shape[0] bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) return compute_cost @ignore_sharding_exception def no_split(self): - name = f'R = R dot R' - dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}} + name = f"R = R dot R" + dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_one_dim(self, mesh_dim): - name = f'R = S{mesh_dim} dot S{mesh_dim}' + name = f"R = S{mesh_dim} dot S{mesh_dim}" # get sharding spec dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}} @@ -87,14 +88,17 @@ def split_one_dim(self, mesh_dim): # get communication action output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) communication_action_mapping = {"output": output_comm_action} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] @@ -112,19 +116,18 @@ def collate_strategies(self) -> List[ShardingStrategy]: class MatVecStrategyGenerator(MatMulStrategyGenerator): - def validate(self) -> bool: - input_op_data = self.op_data['input'] - other_op_data = self.op_data['other'] + input_op_data = self.op_data["input"] + other_op_data = self.op_data["other"] assert input_op_data.data.dim() == 2 and other_op_data.data.dim() == 1 def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() fwd_compute_cost = sharded_input_shape[0] bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) return compute_cost @ignore_sharding_exception @@ -133,67 +136,69 @@ def no_split(self): dim_partition_dict = {"input": {}, "other": {}, "output": {}} if self.has_bias: - dim_partition_dict['bias'] = {} + dim_partition_dict["bias"] = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping={}) + return self.get_sharding_strategy( + name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping={} + ) @ignore_sharding_exception def split_input_batch(self, mesh_dim): - name = f'S{mesh_dim}R = S{mesh_dim}R x R' + name = f"S{mesh_dim}R = S{mesh_dim}R x R" # get sharding spec dim_partition_dict = { - "input": { - 0: [mesh_dim] - }, + "input": {0: [mesh_dim]}, "other": {}, - "output": { - 0: [mesh_dim] - }, + "output": {0: [mesh_dim]}, } if self.has_bias: - dim_partition_dict['bias'] = {} + dim_partition_dict["bias"] = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication action communication_action_mapping = {} - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=1) - communication_action_mapping['other'] = other_comm_action + arg_index=1, + ) + communication_action_mapping["other"] = other_comm_action if self.has_bias: - if self.is_param('bias'): + if self.is_param("bias"): bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=2) - communication_action_mapping['bias'] = bias_comm_action + arg_index=2, + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] @@ -209,12 +214,13 @@ def collate_strategies(self) -> List[ShardingStrategy]: class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): - - def __init__(self, - operation_data_mapping, - device_mesh, - linear_projection_type='linear', - solver_perference=SolverPerference.STANDARD): + def __init__( + self, + operation_data_mapping, + device_mesh, + linear_projection_type="linear", + solver_perference=SolverPerference.STANDARD, + ): super().__init__(operation_data_mapping, device_mesh) self.linear_projection_type = linear_projection_type self.solver_perference = solver_perference @@ -224,17 +230,17 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # C: [M, N], A: [M, P], B: [P, N] # fwd cost = MNP (only count mul) # bwd: 2 x fwd_cost - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_other_shape = strategy.sharding_specs[self.op_data["other"]].get_sharded_shape_per_device() dim_m_val = reduce(operator.mul, sharded_input_shape[:-1]) dim_n_val = sharded_other_shape[-1] dim_p_val = sharded_other_shape[0] fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=bwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=bwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) strategy.compute_cost = compute_cost def dp_strategies(self) -> List[ShardingStrategy]: @@ -301,28 +307,21 @@ def collate_strategies(self) -> List[ShardingStrategy]: @ignore_sharding_exception def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): # handle case SS = SR x RS - name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' + name = f"S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}" dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - -1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0]}, + "other": {-1: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], -1: [mesh_dim_1]}, } # linear bias only has one dimension, but addmm bias has same dimensions # as the output logically. - if self.linear_projection_type == 'linear': - dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]} - elif self.linear_projection_type == 'addmm': - dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]} + if self.linear_projection_type == "linear": + dim_partition_dict_mapping["bias"] = {-1: [mesh_dim_1]} + elif self.linear_projection_type == "addmm": + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0], -1: [mesh_dim_1]} else: - raise ('Unsupported linear projection type') + raise ("Unsupported linear projection type") sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) @@ -333,75 +332,75 @@ def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) - communication_action_mapping['input'] = input_comm_action - communication_action_mapping['other'] = other_comm_action + communication_action_mapping["input"] = input_comm_action + communication_action_mapping["other"] = other_comm_action # we only add allreduce comm action for linear bias, because # allreduce comm action for addmm bias will be considered in post processing - if self.has_bias and self.linear_projection_type == 'linear': - if self.is_param('bias'): + if self.has_bias and self.linear_projection_type == "linear": + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') - communication_action_mapping['bias'] = bias_comm_action + key_for_kwarg="bias", + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): # handle the case SR = SS x SR - name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' + name = f"S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R" # get sharding spec mapping dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0], -1: [mesh_dim_1]}, + "other": {0: [mesh_dim_1]}, "bias": {}, - "output": { - 0: [mesh_dim_0] - }, + "output": {0: [mesh_dim_0]}, } # linear bias only has one dimension, but addmm bias has same dimensions # as the output logically. - if self.linear_projection_type == 'linear': - dim_partition_dict_mapping['bias'] = {} - elif self.linear_projection_type == 'addmm': - dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]} + if self.linear_projection_type == "linear": + dim_partition_dict_mapping["bias"] = {} + elif self.linear_projection_type == "addmm": + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]} else: - raise ('Unsupported linear projection type') + raise ("Unsupported linear projection type") sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) @@ -412,66 +411,64 @@ def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=1) + arg_index=1, + ) - communication_action_mapping['other'] = other_comm_action - communication_action_mapping['output'] = output_comm_action + communication_action_mapping["other"] = other_comm_action + communication_action_mapping["output"] = output_comm_action # we only add allreduce comm action for linear bias, because # allreduce comm action for addmm bias will be considered in post processing - if self.has_bias and self.linear_projection_type == 'linear': - if self.is_param('bias'): + if self.has_bias and self.linear_projection_type == "linear": + if self.is_param("bias"): bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - key_for_kwarg='bias') - communication_action_mapping['bias'] = bias_comm_action + key_for_kwarg="bias", + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' + name = f"RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}" # get sharding specs dim_partition_dict_mapping = { - "input": { - -1: [mesh_dim_0] - }, - "other": { - 0: [mesh_dim_0], - -1: [mesh_dim_1] - }, - "bias": { - -1: [mesh_dim_1] - }, - "output": { - -1: [mesh_dim_1] - }, + "input": {-1: [mesh_dim_0]}, + "other": {0: [mesh_dim_0], -1: [mesh_dim_1]}, + "bias": {-1: [mesh_dim_1]}, + "output": {-1: [mesh_dim_1]}, } # We don't have to do anything special for bias here, because @@ -482,34 +479,34 @@ def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): # get communication actions communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) communication_action_mapping["input"] = input_comm_action - communication_action_mapping['output'] = output_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping["output"] = output_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def recompute_split_both_contract(self, mesh_dim): - name = f'RR = RS{mesh_dim} x S{mesh_dim}R' + name = f"RR = RS{mesh_dim} x S{mesh_dim}R" # get sharding spec dim_partition_dict_mapping = { - "input": { - -1: [mesh_dim] - }, - "other": { - 0: [mesh_dim] - }, + "input": {-1: [mesh_dim]}, + "other": {0: [mesh_dim]}, "bias": {}, "output": {}, } @@ -520,32 +517,29 @@ def recompute_split_both_contract(self, mesh_dim): # get communication action communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim, - comm_type=CommType.AFTER) + comm_type=CommType.AFTER, + ) - communication_action_mapping['output'] = output_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping["output"] = output_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_rhs_space_only(self, mesh_dim): - name = f'RS{mesh_dim} = RR x RS{mesh_dim}' + name = f"RS{mesh_dim} = RR x RS{mesh_dim}" # get sharding spec dim_partition_dict_mapping = { "input": {}, - "other": { - -1: [mesh_dim] - }, - "bias": { - -1: [mesh_dim] - }, - "output": { - -1: [mesh_dim] - }, + "other": {-1: [mesh_dim]}, + "bias": {-1: [mesh_dim]}, + "output": {-1: [mesh_dim]}, } # We don't have to do anything special for bias here, because # the bias is already the same sharding spec as the output. @@ -554,93 +548,94 @@ def split_rhs_space_only(self, mesh_dim): # get communication actions communication_action_mapping = {} input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) - communication_action_mapping['input'] = input_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + communication_action_mapping["input"] = input_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' + name = f"S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR" # get sharding spec dim_partition_dict_mapping = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, "other": {}, "bias": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "output": {0: [mesh_dim_0, mesh_dim_1]}, } # linear bias only has one dimension, but addmm bias has same dimensions # as the output logically. - if self.linear_projection_type == 'linear': - dim_partition_dict_mapping['bias'] = {} - elif self.linear_projection_type == 'addmm': - dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]} + if self.linear_projection_type == "linear": + dim_partition_dict_mapping["bias"] = {} + elif self.linear_projection_type == "addmm": + dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]} else: - raise ('Unsupported linear projection type') + raise ("Unsupported linear projection type") sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) # get communication action communication_action_mapping = {} - if self.is_param('other'): + if self.is_param("other"): other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=1) - communication_action_mapping['other'] = other_comm_action + arg_index=1, + ) + communication_action_mapping["other"] = other_comm_action # we only add allreduce comm action for linear bias, because # allreduce comm action for addmm bias will be considered in post processing - if self.has_bias and self.linear_projection_type == 'linear': - if self.is_param('bias'): + if self.has_bias and self.linear_projection_type == "linear": + if self.is_param("bias"): bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.HOOK) + comm_type=CommType.HOOK, + ) else: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - key_for_kwarg='bias') - communication_action_mapping['bias'] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + key_for_kwarg="bias", + ) + communication_action_mapping["bias"] = bias_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' + name = f"RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R" # get sharding spec dim_partition_dict_mapping = { - "input": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {-1: [mesh_dim_0, mesh_dim_1]}, + "other": {0: [mesh_dim_0, mesh_dim_1]}, "bias": {}, "output": {}, } @@ -652,32 +647,29 @@ def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): # get communication action communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], - comm_type=CommType.AFTER) - communication_action_mapping['output'] = output_comm_action + comm_type=CommType.AFTER, + ) + communication_action_mapping["output"] = output_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): - name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' + name = f"RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}" # get sharding spec dim_partition_dict_mapping = { "input": {}, - "other": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "bias": { - -1: [mesh_dim_0, mesh_dim_1] - }, - "output": { - -1: [mesh_dim_0, mesh_dim_1] - }, + "other": {-1: [mesh_dim_0, mesh_dim_1]}, + "bias": {-1: [mesh_dim_0, mesh_dim_1]}, + "output": {-1: [mesh_dim_0, mesh_dim_1]}, } # We don't have to do anything special for bias here, because @@ -687,20 +679,23 @@ def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): # get communication action communication_action_mapping = {} input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['input'] = input_comm_action + arg_index=0, + ) + communication_action_mapping["input"] = input_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def non_split(self): - name = f'RR = RR x RR' + name = f"RR = RR x RR" # get sharding spec dim_partition_dict_mapping = { @@ -717,22 +712,24 @@ def non_split(self): # get communication action communication_action_mapping = {} - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def validate(self) -> bool: assert "input" in self.op_data assert "other" in self.op_data # make sure the other has 2 dim - input_data = self.op_data['input'] - other_data = self.op_data['other'] + input_data = self.op_data["input"] + other_data = self.op_data["other"] assert input_data.data.dim() > 0 and other_data.data.dim() == 2 assert other_data.logical_shape[0] == input_data.logical_shape[-1] if self.has_bias: - bias_data = self.op_data['bias'] + bias_data = self.op_data["bias"] assert bias_data.logical_shape[-1] == other_data.logical_shape[-1] @@ -757,37 +754,38 @@ def __init__(self, *args, **kwargs): def _pop_batch_dim_sharding_for_output(self, dim_partition_dict): # remove partition dict for dim 0 - dim_partition_dict['output'].pop(0, None) + dim_partition_dict["output"].pop(0, None) # decrease the remaining dim index by 1 temp_dim_partition = {} - keys = list(dim_partition_dict['output'].keys()) + keys = list(dim_partition_dict["output"].keys()) for key in keys: - val = dim_partition_dict['output'].pop(key) + val = dim_partition_dict["output"].pop(key) temp_dim_partition[key - 1] = val - dim_partition_dict['output'].update(temp_dim_partition) + dim_partition_dict["output"].update(temp_dim_partition) def validate(self) -> bool: - input_op_data = self.op_data['input'] - other_op_data = self.op_data['other'] + input_op_data = self.op_data["input"] + other_op_data = self.op_data["other"] assert len(input_op_data.logical_shape) == 3 or len(other_op_data.logical_shape) == 3 - if 'bias' in self.op_data: - bias_op_data = self.op_data['bias'] + if "bias" in self.op_data: + bias_op_data = self.op_data["bias"] assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2 def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: - fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul, - self.op_data['output'].data.shape) + fwd_compute_cost = self.op_data["input"].data.shape[-1] * reduce( + operator.mul, self.op_data["output"].data.shape + ) bwd_compute_cost = fwd_compute_cost * 2 - compute_cost = TrainCycleItem(fwd=fwd_compute_cost, - bwd=bwd_compute_cost, - total=fwd_compute_cost + bwd_compute_cost) + compute_cost = TrainCycleItem( + fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost + ) strategy.compute_cost = compute_cost @ignore_sharding_exception def split_one_batch_dim(self, mesh_dim): - name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}' + name = f"Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}" # get sharding_spec dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}} @@ -799,30 +797,27 @@ def split_one_batch_dim(self, mesh_dim): communication_action_mapping = {} if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim, comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}' + name = f"Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0, mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0, mesh_dim_1] - }, + "input": {0: [mesh_dim_0, mesh_dim_1]}, + "other": {0: [mesh_dim_0, mesh_dim_1]}, "bias": {}, - "output": { - 0: [mesh_dim_0, mesh_dim_1] - } + "output": {0: [mesh_dim_0, mesh_dim_1]}, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -832,35 +827,28 @@ def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1): communication_action_mapping = {} if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}' + name = f"Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0] - }, - "bias": { - 0: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - } + "input": {0: [mesh_dim_0], 1: [mesh_dim_1]}, + "other": {0: [mesh_dim_0]}, + "bias": {0: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], 1: [mesh_dim_1]}, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -869,46 +857,40 @@ def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1): # get communication actions communication_action_mapping = {} other_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['other'], + sharding_spec=sharding_spec_mapping["other"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=1) - communication_action_mapping['other'] = other_comm_action + arg_index=1, + ) + communication_action_mapping["other"] = other_comm_action if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1], comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action # for addbmm case, other is the third argument instead of second. - communication_action_mapping['other'].arg_index += 1 + communication_action_mapping["other"].arg_index += 1 - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}' + name = f"Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0] - }, - "other": { - 0: [mesh_dim_0], - 2: [mesh_dim_1] - }, - "bias": { - 1: [mesh_dim_1] - }, - "output": { - 0: [mesh_dim_0], - 2: [mesh_dim_1] - } + "input": {0: [mesh_dim_0]}, + "other": {0: [mesh_dim_0], 2: [mesh_dim_1]}, + "bias": {1: [mesh_dim_1]}, + "output": {0: [mesh_dim_0], 2: [mesh_dim_1]}, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -917,43 +899,41 @@ def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1): # get communication actions communication_action_mapping = {} input_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['input'], + sharding_spec=sharding_spec_mapping["input"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1, comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['input'] = input_comm_action + arg_index=0, + ) + communication_action_mapping["input"] = input_comm_action if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, - comm_type=CommType.BEFORE) - communication_action_mapping['bias'] = bias_comm_action + comm_type=CommType.BEFORE, + ) + communication_action_mapping["bias"] = bias_comm_action # for addbmm case, other is the second argument instead of first. - communication_action_mapping['input'].arg_index += 1 + communication_action_mapping["input"].arg_index += 1 - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) @ignore_sharding_exception def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): - name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}' + name = f"Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}" dim_partition_dict = { - "input": { - 0: [mesh_dim_0], - 2: [mesh_dim_1] - }, - "other": { - 0: [mesh_dim_0], - 1: [mesh_dim_1] - }, + "input": {0: [mesh_dim_0], 2: [mesh_dim_1]}, + "other": {0: [mesh_dim_0], 1: [mesh_dim_1]}, "bias": {}, "output": { 0: [mesh_dim_0], - } + }, } if self.squeeze_batch_dim: self._pop_batch_dim_sharding_for_output(dim_partition_dict) @@ -962,24 +942,28 @@ def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): # get communication actions communication_action_mapping = {} output_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['output'], + sharding_spec=sharding_spec_mapping["output"], communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1, - comm_type=CommType.AFTER) - communication_action_mapping['output'] = output_comm_action + comm_type=CommType.AFTER, + ) + communication_action_mapping["output"] = output_comm_action if self.has_bias: bias_comm_action = self.get_communication_action( - sharding_spec=sharding_spec_mapping['bias'], + sharding_spec=sharding_spec_mapping["bias"], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_0, comm_type=CommType.BEFORE, - arg_index=0) - communication_action_mapping['bias'] = bias_comm_action - - return self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + arg_index=0, + ) + communication_action_mapping["bias"] = bias_comm_action + + return self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py index b7db42f8f67e..b307e38b5b6d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py @@ -21,28 +21,31 @@ class NormalPoolStrategyGenerator(StrategyGenerator): """ def validate(self) -> bool: - ''' + """ In sanity check, we need make sure the input data having correct dimension size. For Pool1d, the dim of input data should be 3([N, C, L]). For Pool2d, the dim of input data should be 4([N, C, H, W]). For Pool3d, the dim of input data should be 5([N, C, H, W, D]). - ''' - input_op_data = self.op_data['input'] + """ + input_op_data = self.op_data["input"] assert input_op_data.data.dim() in ( - 3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].' + 3, + 4, + 5, + ), f"We suppose the dim of input fed into Pool op should in range of [3, 5]." def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem: - ''' + """ Compute the computation cost per device with this specific strategy. Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size. - ''' + """ # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size. # 1D: (Lout) * N * C * kernel # 2D: (H * W) * N * Cout * Cin * kernel # 3D: (H * W * D) * N * Cout * Cin * kernel - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() kernel_size = self.op_data["other"].data if isinstance(kernel_size, int): @@ -61,8 +64,8 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem: def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -88,12 +91,16 @@ def _generate_strategy_with_dim_partition(self, dim_partition): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' + name = ( + f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' + ) communication_action_mapping = {} - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py index 69d1642d4f80..33fb1ac5c5be 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py @@ -12,7 +12,7 @@ from .strategy_generator import OutputStrategyGenerator -__all__ = ['OutputGenerator'] +__all__ = ["OutputGenerator"] class OutputGenerator(OutputStrategyGenerator): @@ -20,8 +20,13 @@ class OutputGenerator(OutputStrategyGenerator): OutputGenerator is a generic class to generate strategies for Output Node. """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - predecessor_nodes: List[Node], output_option: str): + def __init__( + self, + operation_data_mapping: Dict[str, OperationData], + device_mesh: DeviceMesh, + predecessor_nodes: List[Node], + output_option: str, + ): super().__init__(operation_data_mapping, device_mesh, predecessor_nodes) self.output_option = output_option @@ -33,9 +38,9 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ fwd_mem_cost = MemoryCost(activation=0, parameter=0) bwd_mem_cost = MemoryCost(activation=0, parameter=0) @@ -65,16 +70,18 @@ def replica_strategy(self) -> List[ShardingStrategy]: else: dim_partition_dict_for_output = tuple(dim_partition_dict_for_output) - dim_partition_dict_mapping['output'] = dim_partition_dict_for_output + dim_partition_dict_mapping["output"] = dim_partition_dict_for_output communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Replica Output' + name = "Replica Output" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[ShardingStrategy]: @@ -82,19 +89,15 @@ def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[Shardi Generate distributed strategy for output node. """ # TODO: need to take care of the case when the first element of output only need to be sharded. - output_op_data = self.op_data['output'] + output_op_data = self.op_data["output"] if isinstance(output_op_data.data, tuple): length = len(output_op_data.data) dim_partition_dict_mapping = { - "output": [{ - 0: mesh_list - }] * length, + "output": [{0: mesh_list}] * length, } else: dim_partition_dict_mapping = { - "output": { - 0: mesh_list - }, + "output": {0: mesh_list}, } for index, _ in enumerate(self.predecessor_nodes): mapping_name = f"input_{index}" @@ -103,19 +106,21 @@ def distributed_strategy(self, mesh_list: List[List[int]] = None) -> List[Shardi communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Distributed Output' + name = "Distributed Output" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] mesh_list = [0, 1] - if self.output_option == 'replicated': + if self.output_option == "replicated": strategy_list.append(self.replica_strategy()) - elif self.output_option == 'distributed': + elif self.output_option == "distributed": strategy_list.append(self.distributed_strategy(mesh_list)) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py index 779a7ced93bb..df0862a396d2 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py @@ -10,7 +10,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['PlaceholderGenerator'] +__all__ = ["PlaceholderGenerator"] class PlaceholderGenerator(StrategyGenerator): @@ -18,8 +18,9 @@ class PlaceholderGenerator(StrategyGenerator): PlaceholderGenerator is a generic class to generate strategies for placeholder node. """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - placeholder_option: str): + def __init__( + self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, placeholder_option: str + ): super().__init__(operation_data_mapping, device_mesh) self.placeholder_option = placeholder_option @@ -31,10 +32,10 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' - forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + """ + forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")} # compute fwd cost incurred # fwd_cost = output @@ -58,11 +59,13 @@ def replica_placeholder(self) -> ShardingStrategy: communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Replica Placeholder' + name = "Replica Placeholder" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy @@ -71,29 +74,31 @@ def distributed_placeholder(self, mesh_list) -> ShardingStrategy: Generate distributed strategy for placeholder node. """ dim_partition_dict_mapping = { - "output": { - 0: mesh_list - }, + "output": {0: mesh_list}, } communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Distributed Placeholder' + name = "Distributed Placeholder" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] - if self.placeholder_option == 'distributed': + if self.placeholder_option == "distributed": mesh_list = [0, 1] distributed_strategy = self.distributed_placeholder(mesh_list) strategy_list.append(distributed_strategy) else: - assert self.placeholder_option == 'replicated', f'placeholder_option {self.placeholder_option} is not supported' + assert ( + self.placeholder_option == "replicated" + ), f"placeholder_option {self.placeholder_option} is not supported" replicated_strategy = self.replica_placeholder() strategy_list.append(replicated_strategy) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py index 24f75e352935..48f454553ac7 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py @@ -17,7 +17,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.sharding_spec import ShardingSpec -__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator'] +__all__ = ["ReshapeGenerator", "ViewGenerator", "PermuteGenerator", "TransposeGenerator", "SplitGenerator"] class ReshapeGenerator(FollowingStrategyGenerator): @@ -33,12 +33,12 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -56,8 +56,9 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -77,8 +78,8 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] - origin_shape = self.op_data['input'].data.shape - tgt_shape = self.op_data['tgt_shape'].data + origin_shape = self.op_data["input"].data.shape + tgt_shape = self.op_data["tgt_shape"].data reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) @@ -86,8 +87,9 @@ def collate_strategies(self) -> List[ShardingStrategy]: keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict) if keep_sharding_status: - dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input, - reshape_mapping_dict) + dim_partition_dict_for_output = infer_output_dim_partition_dict( + dim_partition_dict_for_input, reshape_mapping_dict + ) else: dim_partition_dict_for_output = {} @@ -119,7 +121,8 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, logical_process_axis=total_mesh_dim_list, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) # it will gather the input through gather_dim during forward phase. input_comm_action.comm_spec.gather_dim = shard_dim # it will split the input activation grad through shard_dim during backward phase. @@ -127,10 +130,10 @@ def collate_strategies(self) -> List[ShardingStrategy]: elif len(total_mesh_dim_list) >= 2: source_spec = sharding_spec_mapping["input"] - target_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=source_spec.entire_shape, - dim_partition_dict={}) - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + target_spec = ShardingSpec( + device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={} + ) + comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec} input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) else: @@ -139,9 +142,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: if input_comm_action is not None: communication_action_mapping["input"] = input_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -159,7 +164,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] - permute_dims = self.op_data['permute_dims'].data + permute_dims = self.op_data["permute_dims"].data dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict dim_partition_dict_for_output = {} for dim_index, permute_dim in enumerate(permute_dims): @@ -177,9 +182,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -199,7 +206,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict dim_partition_dict_for_output = {} - transpose_dims = self.op_data['transpose_dims'].data + transpose_dims = self.op_data["transpose_dims"].data dim_0 = transpose_dims[0] dim_1 = transpose_dims[1] for dim, sharded_dims in dim_partition_dict_for_input.items(): @@ -221,9 +228,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -242,7 +251,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) - split_size, split_dim = self.op_data['split_info'].data + split_size, split_dim = self.op_data["split_info"].data if split_dim in dim_partition_dict_for_input: recover_dims = dim_partition_dict_for_input.pop(split_dim) @@ -271,7 +280,8 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, logical_process_axis=recover_dims, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) # it will gather the input through gather_dim during forward phase. input_comm_action.comm_spec.gather_dim = split_dim # it will split the input activation grad through split_dim during backward phase. @@ -282,7 +292,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: source_spec = input_sharding_spec # target sharding spec target_spec = sharding_spec_mapping["input"] - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec} input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) else: @@ -291,9 +301,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: if input_comm_action is not None: communication_action_mapping["input"] = input_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list @@ -341,16 +353,17 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, logical_process_axis=total_mesh_dim_list, comm_type=CommType.BEFORE, - arg_index=0) + arg_index=0, + ) input_comm_action.comm_spec.gather_dim = total_mesh_dim_list input_comm_action.comm_spec.shard_dim = total_mesh_dim_list elif len(total_mesh_dim_list) >= 2: source_spec = sharding_spec_mapping["input"] - target_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=source_spec.entire_shape, - dim_partition_dict={}) - comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + target_spec = ShardingSpec( + device_mesh=self.device_mesh, entire_shape=source_spec.entire_shape, dim_partition_dict={} + ) + comm_spec = {"src_spec": source_spec, "tgt_spec": target_spec} input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) else: @@ -358,9 +371,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: if input_comm_action is not None: communication_action_mapping["input"] = input_comm_action - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py index a1ebadd043e2..d4382f9941d2 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py @@ -4,21 +4,9 @@ from typing import List from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) -from colossalai.auto_parallel.tensor_shard.utils import ( - check_keep_sharding_status, - detect_reshape_mapping, - infer_output_dim_partition_dict, -) -from colossalai.tensor.shape_consistency import CollectiveCommPattern - -__all__ = ['SoftmaxGenerator'] +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem + +__all__ = ["SoftmaxGenerator"] class SoftmaxGenerator(FollowingStrategyGenerator): @@ -30,11 +18,11 @@ def validate(self) -> bool: return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the computation cost per device with this specific strategy. - ''' - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + """ + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() input_size_product = reduce(operator.mul, sharded_input_shape) output_size_product = reduce(operator.mul, sharded_output_shape) @@ -45,12 +33,12 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -68,8 +56,9 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -80,10 +69,10 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) - softmax_dim = self.op_data['softmax_dim'].data + softmax_dim = self.op_data["softmax_dim"].data if softmax_dim in dim_partition_dict_for_input: - recover_dims = dim_partition_dict_for_input.pop(softmax_dim) + dim_partition_dict_for_input.pop(softmax_dim) dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input) dim_partition_dict_mapping = { @@ -96,9 +85,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py index d42429745c61..7bf2c8cc12a3 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py @@ -39,7 +39,7 @@ def has_bias(self): """ A utility method to check for the existence of bias operand for convenience. """ - return 'bias' in self.op_data + return "bias" in self.op_data def is_param(self, op_data_name): other_data = self.op_data[op_data_name] @@ -49,8 +49,12 @@ def is_buffer(self, op_data_name): other_data = self.op_data[op_data_name] return other_data.type == OperationDataType.BUFFER - def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec], - communication_action_mapping: Dict[str, CommSpec]): + def get_sharding_strategy( + self, + name: str, + sharding_spec_mapping: Dict[str, ShardingSpec], + communication_action_mapping: Dict[str, CommSpec], + ): """ A factory method to produce a ShardingStrategy object. @@ -80,24 +84,28 @@ def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]): op_data = self.op_data[op_data_name] def _to_sharding_spec( - data: any, logical_shape: any, - dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]: + data: any, logical_shape: any, dim_partition_dict: Dict[int, List[int]] + ) -> Union[ShardingSpec, List[ShardingSpec], None]: """ This is a recursive function to convert the dim partition dict to a ShardingSpec object. """ if isinstance(data, torch.Tensor): dim_size = len(logical_shape) dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict) - sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=logical_shape, - dim_partition_dict=dim_partition_dict) + sharding_spec = ShardingSpec( + device_mesh=self.device_mesh, + entire_shape=logical_shape, + dim_partition_dict=dim_partition_dict, + ) return sharding_spec elif isinstance(data, (list, tuple)): sharding_spec = [] for data_element, logical_shape_element, dim_partition_dict_element in zip( - data, logical_shape, dim_partition_dict): + data, logical_shape, dim_partition_dict + ): sharding_spec.append( - _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)) + _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element) + ) return sharding_spec else: return None @@ -116,31 +124,41 @@ def replace_op_name_with_op_data(self, mapping: Dict[str, Any]): results[op_data] = v return results - def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern, - logical_process_axis: Union[int, List[int]]): + def get_communication_spec( + self, + sharding_spec: ShardingSpec, + communication_pattern: CollectiveCommPattern, + logical_process_axis: Union[int, List[int]], + ): """ A factory method to produce a CommSpec object. """ - return CommSpec(comm_pattern=communication_pattern, - sharding_spec=sharding_spec, - logical_process_axis=logical_process_axis) - - def get_communication_action(self, - sharding_spec: ShardingSpec, - communication_pattern: CollectiveCommPattern, - logical_process_axis: Union[int, List[int]], - comm_type: CommType, - arg_index: int = -1, - key_for_kwarg: any = None) -> CommAction: + return CommSpec( + comm_pattern=communication_pattern, sharding_spec=sharding_spec, logical_process_axis=logical_process_axis + ) + + def get_communication_action( + self, + sharding_spec: ShardingSpec, + communication_pattern: CollectiveCommPattern, + logical_process_axis: Union[int, List[int]], + comm_type: CommType, + arg_index: int = -1, + key_for_kwarg: any = None, + ) -> CommAction: """ A factory method to produce a CommAction object. """ - return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec, - communication_pattern=communication_pattern, - logical_process_axis=logical_process_axis), - comm_type=comm_type, - arg_index=arg_index, - key_for_kwarg=key_for_kwarg) + return CommAction( + comm_spec=self.get_communication_spec( + sharding_spec=sharding_spec, + communication_pattern=communication_pattern, + logical_process_axis=logical_process_axis, + ), + comm_type=comm_type, + arg_index=arg_index, + key_for_kwarg=key_for_kwarg, + ) def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ @@ -155,9 +173,9 @@ def _compute_and_add(op_data: OperationData, comm_spec: CommSpec): size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() for phase, cost in num_ele_in_comm.items(): num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes - comm_cost.fwd += num_ele_in_comm['forward'] - comm_cost.bwd += num_ele_in_comm['backward'] - comm_cost.total += num_ele_in_comm['total'] + comm_cost.fwd += num_ele_in_comm["forward"] + comm_cost.bwd += num_ele_in_comm["backward"] + comm_cost.total += num_ele_in_comm["total"] # check if communication action exists # if so, loop over each action and compute the cost of each action @@ -169,8 +187,8 @@ def _compute_and_add(op_data: OperationData, comm_spec: CommSpec): # this condition branch will be removed after all the handler updated. comm_spec = comm_action if isinstance(comm_spec, dict): - src_spec = comm_spec['src_spec'] - tgt_spec = comm_spec['tgt_spec'] + src_spec = comm_spec["src_spec"] + tgt_spec = comm_spec["tgt_spec"] shape_consistency_manager = ShapeConsistencyManager() _, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec) for comm_spec_ in comm_action_sequence: @@ -187,14 +205,12 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ Customize this method to compute the computation flops. """ - pass @abstractmethod def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ Customize this method to compute the memory cost in bytes. """ - pass def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str): """ @@ -212,13 +228,14 @@ def _compute_size_in_bytes_helper(sharding_spec, meta_data): num_elements = 1 else: num_elements = reduce(operator.mul, sharded_shape) - dtype = getattr(meta_data, 'dtype') + dtype = getattr(meta_data, "dtype") size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() return num_elements * size_per_elem_bytes if isinstance(op_data.data, tuple): - assert isinstance(strategy.sharding_specs[op_data], list), \ - 'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.' + assert isinstance( + strategy.sharding_specs[op_data], list + ), "sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple." total_bytes = 0 for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]): meta_data = op_data.data[index] @@ -270,7 +287,6 @@ def validate(self) -> bool: Validate if the operands are of desired shape. If True, means this generator can be used for the current operation. """ - pass class FollowingStrategyGenerator(StrategyGenerator): @@ -280,8 +296,9 @@ class FollowingStrategyGenerator(StrategyGenerator): TODO: remove the original strategy_generator.py after refactoring """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - predecessor_node: Node): + def __init__( + self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_node: Node + ): self.op_data = operation_data_mapping self.device_mesh = device_mesh self.predecessor_node = predecessor_node @@ -292,7 +309,8 @@ class OutputStrategyGenerator(StrategyGenerator): OutputStrategyGenerator is used to generate the sharding strategies for Output Node. """ - def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, - predecessor_nodes: List[Node]): + def __init__( + self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_nodes: List[Node] + ): super().__init__(operation_data_mapping, device_mesh) self.predecessor_nodes = predecessor_nodes diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py index a0fbc58d70c0..dcbf34cfd65b 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py @@ -4,22 +4,9 @@ from typing import List from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) -from colossalai.auto_parallel.tensor_shard.utils import ( - check_keep_sharding_status, - detect_reshape_mapping, - infer_output_dim_partition_dict, -) -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from colossalai.tensor.sharding_spec import ShardingSpec - -__all__ = ['SumGenerator'] +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem + +__all__ = ["SumGenerator"] class SumGenerator(FollowingStrategyGenerator): @@ -31,24 +18,24 @@ def validate(self) -> bool: return super().validate() def update_compute_cost(self, strategy: ShardingStrategy): - sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() - sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device() + sharded_input_shape = strategy.sharding_specs[self.op_data["input"]].get_sharded_shape_per_device() + sharded_output_shape = strategy.sharding_specs[self.op_data["output"]].get_sharded_shape_per_device() input_size_product = reduce(operator.mul, sharded_input_shape) output_size_product = reduce(operator.mul, sharded_output_shape) - compute_cost = TrainCycleItem(fwd=input_size_product, - bwd=output_size_product, - total=input_size_product + output_size_product) + compute_cost = TrainCycleItem( + fwd=input_size_product, bwd=output_size_product, total=input_size_product + output_size_product + ) strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -66,8 +53,9 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -78,7 +66,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict) - sum_dims, sum_mapping_dict = self.op_data['sum_info'].data + sum_dims, sum_mapping_dict = self.op_data["sum_info"].data # TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce # among all the shard groups @@ -90,7 +78,7 @@ def collate_strategies(self) -> List[ShardingStrategy]: elif dim in sum_mapping_dict: dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim] else: - raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims') + raise RuntimeError(f"dim {dim} is not in sum_mapping_dict or sum_dims") for dim in recover_dims: dim_partition_dict_for_input.pop(dim) @@ -105,9 +93,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py index 93cfc9eeea53..eea00c2fa064 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py @@ -1,19 +1,10 @@ -import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - CommAction, - CommType, - MemoryCost, - ShardingStrategy, - TrainCycleItem, -) -from colossalai.tensor.shape_consistency import CollectiveCommPattern -from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem from .strategy_generator import StrategyGenerator -__all__ = ['TensorConstructorGenerator'] +__all__ = ["TensorConstructorGenerator"] class TensorConstructorGenerator(StrategyGenerator): @@ -30,10 +21,10 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' - forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")} + """ + forward_size_mapping = {"output": self._compute_size_in_bytes(strategy, "output")} # compute fwd cost incurred # fwd_cost = input + output @@ -57,11 +48,13 @@ def collate_strategies(self) -> List[ShardingStrategy]: communication_action_mapping = {} sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) - name = 'Replica Tensor Constructor' + name = "Replica Tensor Constructor" - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py index 39799a67c5a0..943cf3f1f50d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py @@ -5,7 +5,7 @@ from .strategy_generator import FollowingStrategyGenerator -__all__ = ['UnaryElementwiseGenerator'] +__all__ = ["UnaryElementwiseGenerator"] class UnaryElementwiseGenerator(FollowingStrategyGenerator): @@ -21,12 +21,12 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'input': self._compute_size_in_bytes(strategy, "input"), - 'output': self._compute_size_in_bytes(strategy, "output") + "input": self._compute_size_in_bytes(strategy, "input"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -44,8 +44,9 @@ def update_memory_cost(self, strategy: ShardingStrategy): bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) # compute total cost - total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, - parameter=fwd_parameter_cost + bwd_parameter_cost) + total_mem_cost = MemoryCost( + activation=fwd_activation_cost + bwd_activation_cost, parameter=fwd_parameter_cost + bwd_parameter_cost + ) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) strategy.memory_cost = memory_cost @@ -69,9 +70,11 @@ def collate_strategies(self) -> List[ShardingStrategy]: # we keep same strategies with different name for node merging, and it will not increase the searching space, # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) strategy_list.append(strategy) return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py index fa941f2cc51d..b27b4f3d4056 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py @@ -10,7 +10,7 @@ from .strategy_generator import StrategyGenerator -__all__ = ['WhereGenerator'] +__all__ = ["WhereGenerator"] class WhereGenerator(StrategyGenerator): @@ -26,14 +26,14 @@ def update_compute_cost(self, strategy: ShardingStrategy): strategy.compute_cost = compute_cost def update_memory_cost(self, strategy: ShardingStrategy): - ''' + """ Compute the memory cost per device with this specific strategy. - ''' + """ forward_size_mapping = { - 'condition': self._compute_size_in_bytes(strategy, "condition"), - 'x': self._compute_size_in_bytes(strategy, "x"), - 'y': self._compute_size_in_bytes(strategy, "y"), - 'output': self._compute_size_in_bytes(strategy, "output") + "condition": self._compute_size_in_bytes(strategy, "condition"), + "x": self._compute_size_in_bytes(strategy, "x"), + "y": self._compute_size_in_bytes(strategy, "y"), + "output": self._compute_size_in_bytes(strategy, "output"), } backward_size_mapping = copy.deepcopy(forward_size_mapping) @@ -59,7 +59,7 @@ def _generate_strategy_with_dim_partition(self, dim_partition): "condition": dim_partition, "x": dim_partition, "y": dim_partition, - "output": dim_partition + "output": dim_partition, } sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) @@ -67,9 +67,11 @@ def _generate_strategy_with_dim_partition(self, dim_partition): name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}' communication_action_mapping = {} - strategy = self.get_sharding_strategy(name=name, - sharding_spec_mapping=sharding_spec_mapping, - communication_action_mapping=communication_action_mapping) + strategy = self.get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping, + ) return strategy @@ -84,9 +86,9 @@ def enumerate_all_possible_output_spec(self, mesh_dim_0, mesh_dim_1, dimension_l return dim_partition_list def collate_strategies(self) -> List[ShardingStrategy]: - ''' + """ Generate every possible strategies for a where node, and record all strategies into the strategies_vector. - ''' + """ strategy_list = [] dimension_length = len(self.op_data["output"].logical_shape) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py index 86f90694e060..5b4ea0afe5f8 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/sum_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import StrategyGenerator, SumGenerator -__all__ = ['SumHandler'] +__all__ = ["SumHandler"] @operator_registry.register(torch.Tensor.sum) @@ -55,7 +55,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: # sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input # sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input sum_mapping_dict = {} - if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']: + if "keepdim" in self.node.kwargs and self.node.kwargs["keepdim"]: for i in range(num_dims): sum_mapping_dict.update({i: i}) else: @@ -67,7 +67,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: assert output_index == self.node._meta_data.dim() sum_info = (sum_dims, sum_mapping_dict) - physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info) + physical_shape_operand = OperationData(name="sum_info", type=OperationDataType.ARG, data=sum_info) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -75,7 +75,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "sum_info": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py index 855a2e7612af..c2aa120e8a28 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py @@ -8,7 +8,7 @@ from .strategy import StrategyGenerator from .strategy.tensor_constructor_generator import TensorConstructorGenerator -__all__ = ['TensorConstructorHandler'] +__all__ = ["TensorConstructorHandler"] @operator_registry.register(torch.arange) diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py index 7a9d37726490..b72d9812f406 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import StrategyGenerator, TransposeGenerator -__all__ = ['TransposeHandler'] +__all__ = ["TransposeHandler"] @operator_registry.register(torch.Tensor.transpose) @@ -48,9 +48,9 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: if transpose_dims[i] < 0: transpose_dims[i] += num_dims - physical_shape_operand = OperationData(name='transpose_dims', - type=OperationDataType.ARG, - data=list(transpose_dims)) + physical_shape_operand = OperationData( + name="transpose_dims", type=OperationDataType.ARG, data=list(transpose_dims) + ) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -58,7 +58,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "transpose_dims": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py index 0362de780d7a..cbc873de8223 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py @@ -3,11 +3,11 @@ import torch from ..sharding_strategy import OperationData, OperationDataType -from .node_handler import MetaInfoNodeHandler, NodeHandler +from .node_handler import MetaInfoNodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, UnaryElementwiseGenerator -__all__ = ['UnaryElementwiseHandler'] +__all__ = ["UnaryElementwiseHandler"] @operator_registry.register(torch.Tensor.to) @@ -33,9 +33,9 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_input_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) + physical_input_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) mapping = {"input": physical_input_operand, "output": physical_output} diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py index 7dff89d1d7a3..56c1d10a167e 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py @@ -7,7 +7,7 @@ from .registry import operator_registry from .strategy import StrategyGenerator, ViewGenerator -__all__ = ['ViewHandler'] +__all__ = ["ViewHandler"] @operator_registry.register(torch.Tensor.reshape) @@ -38,7 +38,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) target_shape = self.node._meta_data.shape - physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape) + physical_shape_operand = OperationData(name="tgt_shape", type=OperationDataType.ARG, data=target_shape) output_data = self.node._meta_data physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) @@ -46,7 +46,7 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]: mapping = { "input": physical_input_operand, "tgt_shape": physical_shape_operand, - "output": physical_output_operand + "output": physical_output_operand, } return mapping diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py index 6de2aaafdd01..1856a11100b0 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py @@ -1,16 +1,15 @@ import copy -import operator from typing import Dict, List import torch -from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector +from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy from ..utils import recover_sharding_spec_for_broadcast_shape from .node_handler import NodeHandler from .registry import operator_registry from .strategy import StrategyGenerator, WhereGenerator -__all__ = ['WhereHandler'] +__all__ = ["WhereHandler"] @operator_registry.register(torch.where) @@ -28,27 +27,28 @@ def get_strategy_generator(self) -> List[StrategyGenerator]: def get_operation_data_mapping(self) -> Dict[str, OperationData]: # use transposed shape for strategies # the strategies will be transformed back to its original shape in self.post_process - physical_condition_operand = OperationData(name=str(self.node.args[0]), - type=OperationDataType.ARG, - data=self.node.args[0]._meta_data) - physical_x_operand = OperationData(name=str(self.node.args[1]), - type=OperationDataType.ARG, - data=self.node.args[1]._meta_data) - physical_y_operand = OperationData(name=str(self.node.args[2]), - type=OperationDataType.ARG, - data=self.node.args[2]._meta_data) + physical_condition_operand = OperationData( + name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data + ) + physical_x_operand = OperationData( + name=str(self.node.args[1]), type=OperationDataType.ARG, data=self.node.args[1]._meta_data + ) + physical_y_operand = OperationData( + name=str(self.node.args[2]), type=OperationDataType.ARG, data=self.node.args[2]._meta_data + ) physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data) physical_mapping = { "condition": physical_condition_operand, "x": physical_x_operand, "y": physical_y_operand, - "output": physical_output + "output": physical_output, } logical_shape_for_all = self.node._meta_data.shape logical_mapping = {} for key, physical_operand in physical_mapping.items(): - logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand, - logical_shape_for_all) + logical_mapping[key] = self.convert_physical_operand_to_logical_operand( + physical_operand, logical_shape_for_all + ) return logical_mapping, physical_mapping @@ -64,7 +64,8 @@ def post_process(self, strategy: ShardingStrategy): logical_shape = logical_op_data_mapping[key].logical_shape physical_shape = physical_op_data_mapping[key].logical_shape physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape( - logical_sharding_spec, logical_shape, physical_shape) + logical_sharding_spec, logical_shape, physical_shape + ) strategy.sharding_specs.pop(logical_op_data_mapping[key]) strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}" diff --git a/colossalai/auto_parallel/tensor_shard/options.py b/colossalai/auto_parallel/tensor_shard/options.py index f0ea502a6f0e..e87872f39c10 100644 --- a/colossalai/auto_parallel/tensor_shard/options.py +++ b/colossalai/auto_parallel/tensor_shard/options.py @@ -1,13 +1,14 @@ from dataclasses import dataclass from enum import Enum -__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption'] +__all__ = ["SolverOptions", "SolverPerference", "DataloaderOption", "ShardOption"] class SolverPerference(Enum): """ This enum class is to define the solver preference. """ + STANDARD = 0 DP = 1 TP = 2 @@ -25,6 +26,7 @@ class ShardOption(Enum): TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis. TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes. """ + STANDARD = 0 SHARD = 1 SHARD_LAST_AXIS = 2 @@ -35,6 +37,7 @@ class DataloaderOption(Enum): """ This enum class is to define the dataloader option. """ + REPLICATED = 0 DISTRIBUTED = 1 @@ -44,6 +47,7 @@ class SolverOptions: """ SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. """ + solver_perference: SolverPerference = SolverPerference.STANDARD dataloader_option: DataloaderOption = DataloaderOption.REPLICATED shard_option: ShardOption = ShardOption.STANDARD diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py index 6af927272437..8e22df64d868 100644 --- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -10,7 +10,6 @@ from colossalai.tensor.sharding_spec import ShardingSpec from .constants import ( - BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_METHOD_OP, ELEMENTWISE_MODULE_OP, @@ -18,13 +17,14 @@ RESHAPE_METHOD_OP, ) -__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector'] +__all__ = ["OperationDataType", "OperationData", "TrainCycleItem", "MemoryCost", "ShardingStrategy", "StrategiesVector"] class OperationDataType(Enum): """ An operation can come from the argument list of an operator or the parameter list of a module. """ + INPUT = 0 ARG = 1 PARAM = 2 @@ -43,6 +43,7 @@ class OperationData: data (Any): the value for this data, usually it is a meta tensor. logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory. """ + name: str type: OperationDataType data: Any @@ -69,13 +70,13 @@ def _infer_logical_shape(data: any): self.logical_shape = _infer_logical_shape(self.data) def __repr__(self) -> str: - return f'OperationData(name={self.name}, type={self.type})' + return f"OperationData(name={self.name}, type={self.type})" def __eq__(self, other) -> bool: return other.name == self.name def __hash__(self) -> int: - return hash(f'{self.name}') + return hash(f"{self.name}") @dataclass @@ -88,6 +89,7 @@ class TrainCycleItem: fwd (float): the item for the forward pass bwd (float): the item for the backward pass """ + fwd: Any bwd: Any total: Any @@ -104,6 +106,7 @@ class MemoryCost: temp (int): the memory cost incurred by the temporary tensors in bytes. buffer (int): the memory cost incurred by the module buffer in bytes. """ + activation: int = 0 parameter: int = 0 temp: int = 0 @@ -120,6 +123,7 @@ class CommType(Enum): HOOK: the communication action is used to do the grad all reduce. IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm """ + BEFORE = 0 AFTER = 1 HOOK = 2 @@ -137,6 +141,7 @@ class CommAction: arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime, because the args of node may be changed by graph transform passes. """ + comm_spec: CommSpec = None comm_type: CommType = None arg_index: int = -1 @@ -156,6 +161,7 @@ class ShardingStrategy: memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None) input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes. """ + name: str sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None compute_cost: TrainCycleItem = None @@ -200,7 +206,6 @@ def get_sharding_spec_by_name(self, name: str): raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}") def clone(self): - def _deepcopy_dict_vals(data: Dict): return {k: deepcopy(v) for k, v in data.items()} @@ -209,31 +214,34 @@ def _deepcopy_dict_vals(data: Dict): # Consider the examples below: # If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False. # In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items. - communication_actions = _deepcopy_dict_vals( - self.communication_actions) if self.communication_actions is not None else None + communication_actions = ( + _deepcopy_dict_vals(self.communication_actions) if self.communication_actions is not None else None + ) # same reason as communication_actions resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None compute_cost = deepcopy(self.compute_cost) communication_cost = deepcopy(self.communication_cost) memory_cost = deepcopy(self.memory_cost) - return ShardingStrategy(name=self.name, - sharding_specs=sharding_specs, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - communication_actions=communication_actions, - resharding_costs=resharding_costs) + return ShardingStrategy( + name=self.name, + sharding_specs=sharding_specs, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + communication_actions=communication_actions, + resharding_costs=resharding_costs, + ) class StrategiesVector(list): - ''' + """ Each node in fx graph will have a corresponding StrategiesVector, to store all the possible strategies of the node. Argument: node (Node): node for which the list of sharding strategies are generated. - ''' + """ def __init__(self, node: Node): super().__init__() @@ -245,7 +253,7 @@ def __init__(self, node: Node): def check_merge(self): merge_label = False - if self.node.op == 'call_module': + if self.node.op == "call_module": target = self.node.target root_module = self.node.graph.owning_module submod = root_module.get_submodule(target) @@ -255,7 +263,7 @@ def check_merge(self): if submod_type in ELEMENTWISE_MODULE_OP: merge_label = True - if self.node.op == 'call_function': + if self.node.op == "call_function": # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec. if self.node.target in ELEMENTWISE_FUNC_OP: merge_label = True @@ -267,7 +275,7 @@ def check_merge(self): if self.node.target in RESHAPE_FUNC_OP: merge_label = True - if self.node.op == 'call_method': + if self.node.op == "call_method": # we could merge reshape op, because their computation costs are negligible. method = getattr(self.node.args[0]._meta_data.__class__, self.node.target) if method in RESHAPE_METHOD_OP: diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py index f9e6bd923921..b930ce80a9b9 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py @@ -3,4 +3,4 @@ from .solver import Solver from .strategies_constructor import StrategiesConstructor -__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph'] +__all__ = ["GraphAnalyser", "Solver", "StrategiesConstructor", "CostGraph"] diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py index 1b2d3ad57407..4415d429b0c2 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py +++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py @@ -4,7 +4,7 @@ class CostGraph: - ''' + """ A graph data structure to simplify the edge cost graph. It has two main functions: 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. @@ -15,7 +15,7 @@ class CostGraph: Argument: leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) - ''' + """ def __init__(self, leaf_strategies, simplify=True, forward_only=False): self.leaf_strategies = leaf_strategies @@ -39,10 +39,10 @@ def _remove_invalid_node(self, node, attr_name): target_node_list.remove(element) def _build_cost_graph(self): - ''' + """ This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be set to node. - ''' + """ self.edge_costs = {} if self.simplify: self.merge_pair = [] @@ -84,8 +84,8 @@ def _check_tensor_in_node(data): if _check_tensor_in_node(node._meta_data): children_nodes.append(node) - setattr(dst_node, 'parents', parent_nodes) - setattr(dst_node, 'children', children_nodes) + setattr(dst_node, "parents", parent_nodes) + setattr(dst_node, "children", children_nodes) if self.simplify and strategies_vector.check_merge(): for followed_node in strategies_vector.predecessor_nodes: @@ -99,7 +99,7 @@ def get_edge_cost(self, src_node, dst_node): return self.edge_costs[(src_node, dst_node)] def merge_node(self, src_node, dst_node): - ''' + """ To merge dst_node into src_node, we need to do it in following steps: 1. For each strategy in dst_node, we need to pick an appropriate strategy @@ -119,7 +119,7 @@ def merge_node(self, src_node, dst_node): Argument: src_node(Node): The node will be merged into dst_node. dst_node(Node): The node to integrate src_node. - ''' + """ # build merge_map merge_map = {} for src_index, _ in enumerate(src_node.strategies_vector): @@ -196,7 +196,7 @@ def simplify_graph(self): if not self.simplify: return self.merge_pair.reverse() - for (src_node, dst_node) in self.merge_pair: + for src_node, dst_node in self.merge_pair: self.merge_node(src_node, dst_node) self.merge_pair.reverse() reindexing_following_dict = {} diff --git a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py index 171aa8b3399f..678965d663e4 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py +++ b/colossalai/auto_parallel/tensor_shard/solver/graph_analysis.py @@ -7,7 +7,7 @@ from colossalai.fx.passes.utils import get_node_module -__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] +__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"] @dataclass @@ -15,6 +15,7 @@ class LiveVariable: """ LiveVariable is a data structure to store the meta information of a variable for liveness analysis. """ + name: str node: Node is_inplace: bool @@ -55,6 +56,7 @@ class LiveStage: """ LiveStage is a data structure to record the living variables at this current node. """ + name: str node: Node all_live_vars: LiveVariableVector @@ -62,7 +64,6 @@ class LiveStage: class GraphAnalyser: - def __init__(self, gm: GraphModule): self._gm = gm self._graph = gm.graph @@ -105,18 +106,18 @@ def liveness_analysis(self) -> List[LiveStage]: # detect whether the current op is an in-place op # if it is an in-place op, we would deem it as a duplicate var is_inplace = False - if node.op == 'call_function': + if node.op == "call_function": # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) - if node.kwargs.get('inplace', False): + if node.kwargs.get("inplace", False): is_inplace = True - elif node.op == 'call_module': + elif node.op == "call_module": # to check if this is an inplace op such as torch.nn.Relu(inplace=True) module = get_node_module(node) - if getattr(module, 'inplace', False): + if getattr(module, "inplace", False): is_inplace = True # add the output var - meta = getattr(node, '_meta_data', None) + getattr(node, "_meta_data", None) live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace) if not is_inplace: unique_live_vars.append(live_var) @@ -138,10 +139,12 @@ def liveness_analysis(self) -> List[LiveStage]: # this should be completed if we are able to trace the backward compute graph # add this stage to liveness dict - stage = LiveStage(name=node.name, - node=node, - all_live_vars=all_live_variables.copy(), - unique_live_vars=unique_live_vars.copy()) + stage = LiveStage( + name=node.name, + node=node, + all_live_vars=all_live_variables.copy(), + unique_live_vars=unique_live_vars.copy(), + ) # if a LiveStage is covered by another LiveStage, we just keep the larger one. replace = False for index, prev_stage in enumerate(liveness_list): diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py index 564c5f09220c..088d1acb5177 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/solver.py +++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py @@ -21,24 +21,25 @@ import pulp from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum except: - warnings.warn(f'please install the pulp') + warnings.warn(f"please install the pulp") -__all___ = ['Solver'] +__all___ = ["Solver"] class Solver: - - def __init__(self, - graph: Graph, - strategies_constructor: StrategiesConstructor, - cost_graph: CostGraph, - graph_analyser: GraphAnalyser = None, - memory_budget: float = -1.0, - solution_numbers: int = 1, - forward_only: bool = False, - memory_increasing_coefficient: float = 1.3, - verbose=False): - ''' + def __init__( + self, + graph: Graph, + strategies_constructor: StrategiesConstructor, + cost_graph: CostGraph, + graph_analyser: GraphAnalyser = None, + memory_budget: float = -1.0, + solution_numbers: int = 1, + forward_only: bool = False, + memory_increasing_coefficient: float = 1.3, + verbose=False, + ): + """ Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. Argument: graph: The computing graph to be optimized. @@ -48,7 +49,7 @@ def __init__(self, memory_budget: Memory constraint for the solution. solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget. memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget. - ''' + """ self.graph = graph self.strategies_constructor = strategies_constructor self.cost_graph = cost_graph @@ -75,11 +76,11 @@ def __init__(self, self.verbose = verbose def _recover_merged_node_strategy(self): - ''' + """ During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged node. - ''' + """ for node_index, node in enumerate(self.nodes): if node.strategies_vector.check_merge(): # the merged node has only one input, and its strategies follow the input sharding strategy @@ -98,9 +99,9 @@ def _generate_node_index_dict(self) -> Dict[Node, int]: return node_index_dict def _prepare_data_for_solver(self): - ''' + """ Extract information from components for solver. - ''' + """ node_nums = len(self.leaf_strategies) memory_budget = self.memory_budget @@ -190,23 +191,40 @@ def _prepare_data_for_solver(self): # omit initial value for nodes s_init_np = None - return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose - - def _call_solver_serialized_args(self, - node_nums, - memory_budget, - strategies_len, - following_nodes, - edge_pairs, - alias_set, - liveness_set, - compute_costs, - communication_costs, - memory_costs, - resharding_costs, - alias_convert_costs, - s_init_np=None, - verbose=True): + return ( + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np, + self.verbose, + ) + + def _call_solver_serialized_args( + self, + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np=None, + verbose=True, + ): """ Call the solver with serialized arguments. """ @@ -235,18 +253,18 @@ def get_non_zero_index(binary_vector): s_follow = following_nodes s_alias = alias_set - E = edge_pairs.reshape((-1, 2)) # noqa + E = edge_pairs.reshape((-1, 2)) # noqa r = [] pt = 0 edge_set = set() - for (i, j) in E: + for i, j in E: prod_length = strategies_len[i] * strategies_len[j] if (i, j) in edge_set: raise ValueError(f"Duplicated edges: {(i, j)}") edge_set.add((i, j)) - r.append(resharding_costs[pt:pt + prod_length]) + r.append(resharding_costs[pt : pt + prod_length]) pt += prod_length assert pt == len(resharding_costs) @@ -268,7 +286,6 @@ def get_non_zero_index(binary_vector): # L.append(liveness_set[pt:pt + length]) # pt += length # assert pt == len(liveness_set) - v = [] pt = 0 c = [] @@ -277,9 +294,9 @@ def get_non_zero_index(binary_vector): pt = 0 for i in range(node_nums): length = strategies_len[i] - c.append(compute_costs[pt:pt + length]) - d.append(communication_costs[pt:pt + length]) - m.append(memory_costs[pt:pt + length]) + c.append(compute_costs[pt : pt + length]) + d.append(communication_costs[pt : pt + length]) + m.append(memory_costs[pt : pt + length]) pt += length assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" @@ -319,7 +336,7 @@ def get_non_zero_index(binary_vector): e = [] num_edges = 0 map_edge_to_idx = {} - for (idx, (i, j)) in enumerate(E): + for idx, (i, j) in enumerate(E): if len(s[i]) == 1: e.append(s[j]) elif len(s[j]) == 1: @@ -340,7 +357,7 @@ def get_non_zero_index(binary_vector): ###################################### if s_init_np is not None: s_init = s_init_np.reshape((-1, 3)) - for (idx, value, fix) in s_init: + for idx, value, fix in s_init: for i in range(len(s[idx])): s[idx][i].setInitialValue(i == value) if fix: @@ -393,7 +410,7 @@ def get_non_zero_index(binary_vector): # (d). specified by `cat="Binary"` - for (idx, (i, j)) in enumerate(E): + for idx, (i, j) in enumerate(E): if strategies_len[i] == 1 or strategies_len[j] == 1: continue @@ -402,13 +419,13 @@ def get_non_zero_index(binary_vector): # (f) for row in range(len(s[i])): - C = len(s[j]) # noqa + C = len(s[j]) # noqa prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] # (g) for col in range(len(s[j])): - R = len(s[i]) # noqa - C = len(s[j]) # noqa + R = len(s[i]) # noqa + C = len(s[j]) # noqa prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] # (h) @@ -434,7 +451,8 @@ def get_non_zero_index(binary_vector): msg = verbose time_limit = 600 assert "COIN_CMD" in pulp.listSolvers( - onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") + onlyAvailable=True + ), "Please install ILP solvers by 'sudo apt install coinor-cbc'" solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) @@ -444,13 +462,13 @@ def get_non_zero_index(binary_vector): objective = pulp.value(prob.objective) objective = float(objective) if objective is not None else -1.0 if verbose: - print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" - f"Time: {time.time() - tic}") + print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" f"Time: {time.time() - tic}") print(f"#nodes: {num_nodes}, #edges: {num_edges}") if prob.status in [pulp.LpStatusInfeasible]: - raise RuntimeError("Cannot run the function under the given memory budget. " - "Please increase the memory budget.") + raise RuntimeError( + "Cannot run the function under the given memory budget. " "Please increase the memory budget." + ) # Get and check results s_val = np.full((node_nums,), -1, dtype=np.int32) @@ -458,7 +476,7 @@ def get_non_zero_index(binary_vector): s_val[i] = get_non_zero_index(s[i]) e_val = np.full((len(E),), -1, dtype=np.int32) - for (idx, (i, j)) in enumerate(E): + for idx, (i, j) in enumerate(E): e_val[idx] = get_non_zero_index(e[idx]) i_spec_index = e_val[idx] // len(s[j]) j_spec_index = e_val[idx] % len(s[j]) diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 044a8ac847ea..aa87ee9bf3db 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -1,11 +1,5 @@ -import builtins -import math -import operator -from copy import deepcopy -from typing import Dict, List - import torch -from torch.fx import Graph, Node +from torch.fx import Graph from colossalai.auto_parallel.tensor_shard.node_handler import ( GetattrHandler, @@ -14,13 +8,12 @@ operator_registry, ) from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector -from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks from colossalai.device.device_mesh import DeviceMesh from ..options import DataloaderOption, SolverOptions -__all__ = ['StrategiesConstructor'] +__all__ = ["StrategiesConstructor"] class StrategiesConstructor: @@ -35,7 +28,7 @@ class StrategiesConstructor: def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): self.graph = graph - assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + assert graph.owning_module is not None, "The given graph is not associated with a owning_module" self.root_module = self.graph.owning_module self.nodes = list(graph.nodes) self.device_mesh = device_mesh @@ -46,11 +39,11 @@ def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: Solver self.alias_set = None def remove_duplicated_strategy(self, strategies_vector): - ''' + """ In build_strategies_and_cost method, we may produce some duplicated strategies. In this method, we will remove the duplicated strategies depending on the strategies name. Note that this operation is in-place. - ''' + """ name_checklist = [] remove_list = [] for strategy in strategies_vector: @@ -62,7 +55,6 @@ def remove_duplicated_strategy(self, strategies_vector): strategies_vector.remove(strategy) def generate_alias_set(self): - node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies] common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10) @@ -83,7 +75,7 @@ def build_strategies_and_cost(self): """ def _check_no_strategy_for_node(node): - if node.op in ('placeholder', 'get_attr', 'output'): + if node.op in ("placeholder", "get_attr", "output"): return False def _check_no_strategy_for_data(data): @@ -102,83 +94,93 @@ def _check_no_strategy_for_data(data): if _check_no_strategy_for_node(node): self.no_strategy_nodes.append(node) - pass # placeholder node - elif node.op == 'placeholder': + elif node.op == "placeholder": if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: - placeholder_option = 'distributed' + placeholder_option = "distributed" else: - assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' - placeholder_option = 'replicated' - placeholder_handler = PlaceholderHandler(node, - self.device_mesh, - strategies_vector, - placeholder_option=placeholder_option) + assert ( + self.solver_options.dataloader_option == DataloaderOption.REPLICATED + ), f"placeholder_option {self.solver_options.dataloader_option} is not supported" + placeholder_option = "replicated" + placeholder_handler = PlaceholderHandler( + node, self.device_mesh, strategies_vector, placeholder_option=placeholder_option + ) placeholder_handler.register_strategy() # get_attr node - elif node.op == 'get_attr': - getattr_handler = GetattrHandler(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + elif node.op == "get_attr": + getattr_handler = GetattrHandler( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) getattr_handler.register_strategy() # call_module node - elif node.op == 'call_module': + elif node.op == "call_module": target = node.target submod = self.root_module.get_submodule(target) submod_type = type(submod) - handler = operator_registry.get(submod_type)(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + handler = operator_registry.get(submod_type)( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) handler.register_strategy() # attach strategies_info to node - if hasattr(handler, 'strategies_info'): - setattr(node, 'strategies_info', handler.strategies_info) + if hasattr(handler, "strategies_info"): + setattr(node, "strategies_info", handler.strategies_info) # call_function node - elif node.op == 'call_function': + elif node.op == "call_function": target = node.target - handler = operator_registry.get(target)(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + handler = operator_registry.get(target)( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) handler.register_strategy() # attach strategies_info to node - if hasattr(handler, 'strategies_info'): - setattr(node, 'strategies_info', handler.strategies_info) + if hasattr(handler, "strategies_info"): + setattr(node, "strategies_info", handler.strategies_info) # call_method node - elif node.op == 'call_method': + elif node.op == "call_method": method = getattr(node.args[0]._meta_data.__class__, node.target) - handler = operator_registry.get(method)(node, - self.device_mesh, - strategies_vector, - shard_option=self.solver_options.shard_option, - solver_perference=self.solver_options.solver_perference) + handler = operator_registry.get(method)( + node, + self.device_mesh, + strategies_vector, + shard_option=self.solver_options.shard_option, + solver_perference=self.solver_options.solver_perference, + ) handler.register_strategy() # attach strategies_info to node - if hasattr(handler, 'strategies_info'): - setattr(node, 'strategies_info', handler.strategies_info) + if hasattr(handler, "strategies_info"): + setattr(node, "strategies_info", handler.strategies_info) # output node - elif node.op == 'output': + elif node.op == "output": if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: - output_option = 'distributed' + output_option = "distributed" else: - assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported' - output_option = 'replicated' + assert ( + self.solver_options.dataloader_option == DataloaderOption.REPLICATED + ), f"placeholder_option {self.solver_options.dataloader_option} is not supported" + output_option = "replicated" output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option) output_handler.register_strategy() self.remove_duplicated_strategy(strategies_vector) - setattr(node, 'strategies_vector', strategies_vector) + setattr(node, "strategies_vector", strategies_vector) self.leaf_strategies.append(strategies_vector) self.strategy_map[node] = strategies_vector diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py index b7fe5430bf13..d61cfd2add15 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -17,9 +17,21 @@ ) __all__ = [ - 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', - 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity' - 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', - 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map', - 'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict' + "BroadcastType", + "get_broadcast_shape", + "is_broadcastable", + "recover_sharding_spec_for_broadcast_shape", + "generate_resharding_costs", + "generate_sharding_spec", + "ignore_sharding_exception", + "check_sharding_spec_validity" "transpose_partition_dim", + "update_partition_dim", + "enumerate_all_possible_1d_sharding", + "enumerate_all_possible_2d_sharding", + "generate_sharding_size", + "comm_actions_for_oprands", + "pytree_map", + "detect_reshape_mapping", + "check_keep_sharding_status", + "infer_output_dim_partition_dict", ] diff --git a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py index 307348ea1eaf..99d5a0f2a942 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/broadcast.py +++ b/colossalai/auto_parallel/tensor_shard/utils/broadcast.py @@ -14,8 +14,11 @@ from colossalai.tensor.sharding_spec import ShardingSpec __all__ = [ - 'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape', - 'comm_actions_for_oprands' + "BroadcastType", + "is_broadcastable", + "get_broadcast_shape", + "recover_sharding_spec_for_broadcast_shape", + "comm_actions_for_oprands", ] @@ -41,7 +44,7 @@ def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]: """ Compute the broadcast shape given two shapes. """ - assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable' + assert is_broadcastable(shape1, shape2), f"{shape1} and {shape2} are not broadcastable" shape1_reverse = shape1[::-1] shape2_reverse = shape2[::-1] min_common_dim = min(len(shape1), len(shape2)) @@ -60,8 +63,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape): logical_num_dims = len(logical_shape) physical_num_dims = len(physical_shape) - assert logical_num_dims >= physical_num_dims, \ - 'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!' + assert ( + logical_num_dims >= physical_num_dims + ), "The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!" # track the dim and its broadcasting type logical_dim_broadcast_info = {} @@ -85,8 +89,9 @@ def get_broadcast_dim_info(logical_shape, physical_shape): return logical_dim_broadcast_info -def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, - physical_shape: torch.Size) -> ShardingSpec: +def recover_sharding_spec_for_broadcast_shape( + logical_sharding_spec: ShardingSpec, logical_shape: torch.Size, physical_shape: torch.Size +) -> ShardingSpec: """ This function computes the sharding spec for the physical shape of a broadcast tensor. @@ -124,15 +129,18 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpe physical_dim = physical_num_dims - (logical_num_dims - shape_dim) physical_dim_partition[physical_dim] = mesh_dim - physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh, - entire_shape=physical_shape, - dim_partition_dict=physical_dim_partition) + physical_sharding_spec = ShardingSpec( + device_mesh=logical_sharding_spec.device_mesh, + entire_shape=physical_shape, + dim_partition_dict=physical_dim_partition, + ) return physical_sharding_spec, removed_dims -def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData, - sharding_spec: ShardingSpec) -> CommAction: +def comm_actions_for_oprands( + node: Node, removed_dims: List[int], op_data: OperationData, sharding_spec: ShardingSpec +) -> CommAction: """ This method is used to generate communication actions for oprands which lose information during convert logical shape to physical shape. @@ -140,9 +148,11 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera if len(removed_dims) == 1: # if list length is 1, extract element from list to avoid using flatten device mesh removed_dims = removed_dims[0] - comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - sharding_spec=sharding_spec, - logical_process_axis=removed_dims) + comm_spec = CommSpec( + comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + sharding_spec=sharding_spec, + logical_process_axis=removed_dims, + ) if op_data.type == OperationDataType.PARAM: comm_type = CommType.HOOK else: @@ -151,7 +161,7 @@ def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: Opera for index, arg in enumerate(node.args): if op_data.name == str(arg): arg_index = index - assert arg_index >= 0, f'op_data should be an argument of node.' + assert arg_index >= 0, f"op_data should be an argument of node." comm_action = CommAction( comm_spec=comm_spec, comm_type=comm_type, diff --git a/colossalai/auto_parallel/tensor_shard/utils/factory.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py index 347c10aa102d..aaca923a5eee 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/factory.py +++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py @@ -14,11 +14,12 @@ from ..constants import INFINITY_COST -__all__ = ['generate_sharding_spec', 'generate_resharding_costs'] +__all__ = ["generate_sharding_spec", "generate_resharding_costs"] -def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, - dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: +def generate_sharding_spec( + input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, dim_partition_dict: Dict[int, List[int]] +) -> ShardingSpec: """ Generate the sharding spec of the tensor based on the given dim_partition_dict. @@ -30,7 +31,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic """ if isinstance(input_, Node): - assert hasattr(input_, '_meta_data'), f'The given node has no attribute _meta_data' + assert hasattr(input_, "_meta_data"), f"The given node has no attribute _meta_data" meta_tensor = input_._meta_data assert meta_tensor is not None, "The given node's _meta_data attribute is None" shape = meta_tensor.shape @@ -38,24 +39,27 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic shape = input_.shape else: raise TypeError( - f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.' + f"We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected." ) for dim_index, sharding_index_list in dim_partition_dict.items(): sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] sharding_size = reduce(operator.mul, sharding_list, 1) - assert shape[ - dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' + assert ( + shape[dim_index] % sharding_size == 0 + ), f"we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions." sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) return sharding_spec -def generate_resharding_costs(nodes: List[Node], - sharding_specs: List[ShardingSpec], - count_backward: Optional[bool] = True, - dtype: Optional[torch.dtype] = None, - index=None): - ''' +def generate_resharding_costs( + nodes: List[Node], + sharding_specs: List[ShardingSpec], + count_backward: Optional[bool] = True, + dtype: Optional[torch.dtype] = None, + index=None, +): + """ Compute the resharding costs with this specific strategy. Argument: @@ -63,7 +67,7 @@ def generate_resharding_costs(nodes: List[Node], sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes. count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference. dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. - ''' + """ # The resharding_cost of weight is counted due to sharing weight cases. resharding_costs = {} size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() @@ -76,38 +80,39 @@ def generate_resharding_costs(nodes: List[Node], for strategy in input_node.strategies_vector: input_sharding_spec = strategy.output_sharding_spec if not isinstance(input_sharding_spec, ShardingSpec): - assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.' + assert isinstance(input_sharding_spec, list), "only ShardingSpec or List[ShardingSpec] is expected." input_sharding_spec = input_sharding_spec[index] - assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + assert isinstance(input_sharding_spec, ShardingSpec), f"The input node should NOT be a tuple of tensor." try: # compute the resharding cost _, _, total_resharding_cost = shape_consistency_manager.shape_consistency( - input_sharding_spec, input_spec) + input_sharding_spec, input_spec + ) # we need multiply the size of elem dtype to get correct communication cost resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes except AssertionError as e: - warnings.warn(f'{e}') + warnings.warn(f"{e}") resharding_cost = INFINITY_COST resharding_costs[input_node].append(resharding_cost) return resharding_costs def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20): - ''' + """ Find the largest repeat blocks in the graph, whose length is larger than the threshold. Args: gm (GraphModule): the graph module to be analyzed. common_length_threshold (int): the threshold of the repeat block length. - ''' + """ # graph = gm.graph def _process_args(args): new_args = [] for arg in args: - if hasattr(arg, '_meta_data'): + if hasattr(arg, "_meta_data"): meta_data = arg._meta_data else: meta_data = arg @@ -145,7 +150,7 @@ def _check_node_equal(node1, node2): return False for index, node in enumerate(node_list): - if node.op == 'call_module': + if node.op == "call_module": target = node.target submod = root_module.get_submodule(target) submod_type = type(submod) @@ -155,12 +160,12 @@ def _check_node_equal(node1, node2): new_args = _process_args(node.args) - if node.op != 'get_attr': + if node.op != "get_attr": hash_key = (node.op, target, *new_args) else: hash_key = (node.op,) - setattr(node, 'hash_key', hash_key) + setattr(node, "hash_key", hash_key) hash_value_to_node_dict = {} @@ -179,7 +184,7 @@ def _check_node_equal(node1, node2): # the comparison will be triggered if a common node appears if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2: start_index_list = hash_value_to_node_dict[hash(node.hash_key)] - check_block_list = [node_list[start:start + max_common_length] for start in start_index_list] + check_block_list = [node_list[start : start + max_common_length] for start in start_index_list] common_label = True if not _all_equal(check_block_list, _check_node_list_equal): @@ -201,6 +206,6 @@ def _check_node_equal(node1, node2): # recover common subgraph from the index common_blocks = [] for start in common_blocks_index: - common_blocks.append(node_list[start:start + max_common_length]) + common_blocks.append(node_list[start : start + max_common_length]) return common_blocks diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index 475e95fc4326..42ec2a8ee428 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -1,12 +1,12 @@ import functools -from typing import Any, Callable, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Tuple, Type, Union import torch from colossalai.logging import get_dist_logger from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException -__all__ = ['ignore_sharding_exception', 'pytree_map'] +__all__ = ["ignore_sharding_exception", "pytree_map"] def ignore_sharding_exception(func): @@ -48,29 +48,32 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens tensor_num_dim = tensor.dim() num_devices_in_col = sharding_spec.device_mesh.shape[0] num_devices_in_row = sharding_spec.device_mesh.shape[1] - assert sharding_len == tensor_num_dim, \ - f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' + assert ( + sharding_len == tensor_num_dim + ), f"The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape})." # make sure the sharding is valid for each dim for i in range(tensor_num_dim): dim_size = tensor.shape[i] dim_spec = sharding_spec.sharding_sequence[i] - if str(dim_spec).startswith('S'): - devices_str = str(dim_spec).lstrip('S') + if str(dim_spec).startswith("S"): + devices_str = str(dim_spec).lstrip("S") num_devices = 1 - if '0' in devices_str: + if "0" in devices_str: num_devices *= num_devices_in_col - if '1' in devices_str: + if "1" in devices_str: num_devices *= num_devices_in_row - assert dim_size >= num_devices and dim_size % num_devices == 0, \ - f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.' + assert ( + dim_size >= num_devices and dim_size % num_devices == 0 + ), f"The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices." # make sure the entire shape matches the physical tensor shape - assert sharding_spec.entire_shape == tensor.shape, \ - f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}' + assert ( + sharding_spec.entire_shape == tensor.shape + ), f"The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}" def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py index d0ebbd7e8b1b..329312ef797f 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/reshape.py +++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py @@ -8,6 +8,7 @@ class PreviousStatus(Enum): """ This class shows the status of previous comparison. """ + RESET = 0 # ORIGIN means the dimension size of original tensor is larger in the previous comparison. ORIGIN = 1 @@ -130,8 +131,9 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D return reshape_mapping_dict -def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], - reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool: +def check_keep_sharding_status( + input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]] +) -> bool: """ This method is used to check whether the reshape operation could implement without converting the input to fully replicated status. @@ -172,14 +174,16 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], return True -def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]], - reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]: +def infer_output_dim_partition_dict( + input_dim_partition_dict: Dict[int, List[int]], reshape_mapping_dict: Dict[Tuple[int], Tuple[int]] +) -> Dict[Tuple[int], Tuple[int]]: """ This method is used to infer the output dim partition dict for a reshape operation, given the input dim partition dict and reshape mapping dict. """ - assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \ - 'we only infer output dim partition dict for the reshape operation could keep sharding spec.' + assert check_keep_sharding_status( + input_dim_partition_dict, reshape_mapping_dict + ), "we only infer output dim partition dict for the reshape operation could keep sharding spec." sharded_dims = list(input_dim_partition_dict.keys()) output_dim_partition_dict = {} for input_dims, output_dims in reshape_mapping_dict.items(): diff --git a/colossalai/auto_parallel/tensor_shard/utils/sharding.py b/colossalai/auto_parallel/tensor_shard/utils/sharding.py index e2ce59e0b577..b5386d599be4 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/sharding.py +++ b/colossalai/auto_parallel/tensor_shard/utils/sharding.py @@ -8,8 +8,11 @@ from colossalai.tensor.sharding_spec import ShardingSpec __all__ = [ - 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', - 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' + "transpose_partition_dim", + "update_partition_dim", + "enumerate_all_possible_1d_sharding", + "enumerate_all_possible_2d_sharding", + "generate_sharding_size", ] @@ -22,8 +25,7 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) - dim1 (int): the tensor dimension to switch dim2 (int): the tensor dimension to switch """ - assert len(sharding_spec.entire_shape) >= 2, \ - 'The entire_shape of the sharding spec must have at least 2 dimensions' + assert len(sharding_spec.entire_shape) >= 2, "The entire_shape of the sharding spec must have at least 2 dimensions" dim_partition_dict = sharding_spec.dim_partition_dict # transpose the dim partition @@ -45,10 +47,9 @@ def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) - return sharding_spec -def update_partition_dim(sharding_spec: ShardingSpec, - dim_mapping: Dict[int, int], - physical_shape: torch.Size, - inplace: bool = False): +def update_partition_dim( + sharding_spec: ShardingSpec, dim_mapping: Dict[int, int], physical_shape: torch.Size, inplace: bool = False +): """ This method is used to update the partition dim dict from the logical one to the physical one. @@ -78,9 +79,9 @@ def update_partition_dim(sharding_spec: ShardingSpec, new_dim_partition_dict[tensor_dim] = mesh_dims # update sharding spec - current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh, - entire_shape=physical_shape, - dim_partition_dict=new_dim_partition_dict) + current_sharding_spec.__init__( + device_mesh=sharding_spec.device_mesh, entire_shape=physical_shape, dim_partition_dict=new_dim_partition_dict + ) return current_sharding_spec diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index cc98c1570b4a..9571fa2c17f0 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -9,7 +9,18 @@ AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta() if AUTOCHUNK_AVAILABLE: - from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods + from torch.fx.graph import ( + CodeGen, + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + inplace_methods, + magic_methods, + ) from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg @@ -64,14 +75,21 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out for i in range(len(chunk_output)): shape_str = str(list(get_node_shape(chunk_output[i]))) if get_node_name(chunk_output[i]) in ["split", "unbind"]: - tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name, - input_node.name) - tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta']) + tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % ( + shape_str, + input_node.name, + input_node.name, + ) + tensor_str = tensor_str * len(chunk_output[i].meta["tensor_meta"]) tensor_str = "[" + tensor_str[:-2] + "]" context += "%s = %s; " % (chunk_output[i].name, tensor_str) else: - context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str, - input_node.name, input_node.name) + context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % ( + chunk_output[i].name, + shape_str, + input_node.name, + input_node.name, + ) out_shape = get_node_shape(chunk_output[0]) chunk_shape = out_shape[chunk_output_dim[0]] @@ -79,8 +97,14 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_out return context -def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node], - chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str: +def _gen_loop_end( + chunk_inputs: List[Node], + chunk_non_compute_inputs: List[Node], + node_list: List[Node], + chunk_outputs_idx: int, + chunk_outputs_non_tensor: List[Node], + search_chunk: SearchChunk, +) -> str: """ Generate chunk loop end @@ -148,8 +172,10 @@ def _replace_new_tensor_like_shape( chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] if get_node_shape(meta_node)[chunk_dim] != 1: source_node = meta_node.args[0].args[0] - if (source_node not in chunk_infos[region_idx]["node_chunk_dim"] - or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None): + if ( + source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None + ): chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node)) body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice) return body @@ -203,11 +229,12 @@ def _add_node_slice( # outputs node else: if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]): - chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", - get_node_shape(chunk_node)) + chunk_slice = _gen_chunk_slice_dim( + chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx", get_node_shape(chunk_node) + ) if get_node_name(chunk_node) in ["split", "unbind"]: split_chunk_slice = "" - for i in range(len(chunk_node.meta['tensor_meta'])): + for i in range(len(chunk_node.meta["tensor_meta"])): split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice) split_chunk_slice = split_chunk_slice[:-2] body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice) @@ -216,13 +243,15 @@ def _add_node_slice( return body -def emit_code_with_chunk(body: List[str], - nodes: Iterable[Node], - emit_node_func: Callable, - delete_unused_value_func: Callable, - search_chunk: SearchChunk, - chunk_infos: List, - eval_mem: bool = False): +def emit_code_with_chunk( + body: List[str], + nodes: Iterable[Node], + emit_node_func: Callable, + delete_unused_value_func: Callable, + search_chunk: SearchChunk, + chunk_infos: List, + eval_mem: bool = False, +): """ Emit code with chunk according to chunk_infos. @@ -244,9 +273,9 @@ def emit_code_with_chunk(body: List[str], chunk_ends = [i["region"][1] for i in chunk_infos] # chunk inputs - chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk - chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim + chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk + chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk + chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i] # chunk outputs @@ -275,7 +304,8 @@ def emit_code_with_chunk(body: List[str], chunk_outputs[region_idx], chunk_outputs_dim[region_idx], chunk_infos[region_idx]["chunk_size"], - )) + ) + ) if within_chunk_region: emit_node_func(node, body) @@ -294,7 +324,8 @@ def emit_code_with_chunk(body: List[str], if eval_mem: body.append( " if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" - % (node.name)) + % (node.name) + ) else: emit_node_func(node, body) if node_idx not in chunk_inputs: @@ -302,13 +333,21 @@ def emit_code_with_chunk(body: List[str], if eval_mem: body.append( "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" - % (node.name)) + % (node.name) + ) # generate chunk region end if node_idx in chunk_ends: body.append( - _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list, - chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk)) + _gen_loop_end( + chunk_inputs[region_idx], + chunk_inputs_non_chunk[region_idx], + node_list, + chunk_ends[region_idx], + chunk_outputs_non_tensor[region_idx], + search_chunk, + ) + ) within_chunk_region = False node_idx += 1 @@ -317,13 +356,14 @@ def emit_code_with_chunk(body: List[str], if AUTOCHUNK_AVAILABLE: class AutoChunkCodeGen(CodeGen): - - def __init__(self, - meta_graph, - max_memory: int = None, - print_mem: bool = False, - print_progress: bool = False, - eval_mem: bool = False) -> None: + def __init__( + self, + meta_graph, + max_memory: int = None, + print_mem: bool = False, + print_progress: bool = False, + eval_mem: bool = False, + ) -> None: super().__init__() self.eval_mem = eval_mem # find the chunk regions @@ -349,7 +389,7 @@ def add_global(name_hint: str, obj: Any): Returns: the global name that should be used to reference 'obj' in generated source. """ - if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -402,7 +442,6 @@ def type_repr(o: Any): return add_global(typename, o) def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. if isinstance(arg, tuple) and hasattr(arg, "_fields"): @@ -457,10 +496,10 @@ def delete_unused_values(user: Node, body, to_keep=[]): # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}") + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}") + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") raw_name = node.target.replace("*", "") if raw_name != repr(node): @@ -470,42 +509,56 @@ def emit_node(node: Node, body): assert isinstance(node.target, str) body.append( f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" - f"({_format_args(node.args[1:], node.kwargs)})") + f"({_format_args(node.args[1:], node.kwargs)})" + ) return elif node.op == "call_function": assert callable(node.target) # pretty print operators - if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods): + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f"{repr(node)}{maybe_type_annotation} = " - f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}") + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods): - body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " - f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}") + if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods: + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str) - and node.args[1].isidentifier() and len(node.args) == 2): + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}") + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})") + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f"{repr(node)}{maybe_type_annotation} = " - f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})") + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return elif node.op == "get_attr": assert isinstance(node.target, str) @@ -523,8 +576,9 @@ def emit_node(node: Node, body): # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, - self.eval_mem) + emit_code_with_chunk( + body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, self.eval_mem + ) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py index 77bc2ef17bc3..a85ad429e261 100644 --- a/colossalai/autochunk/estimate_memory.py +++ b/colossalai/autochunk/estimate_memory.py @@ -1,11 +1,8 @@ -import copy -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Dict, List import torch from torch.fx.node import Node -from colossalai.fx.profiler import activation_size, parameter_size - from .utils import NodeMgr, get_node_shape, is_non_memory_node @@ -62,12 +59,9 @@ def _build_delete_node_dict(self, node_mgr: NodeMgr) -> Dict: delete_node_dict[node] = max(node_user_idx) return delete_node_dict - def _remove_deactive_node(self, - user_idx: int, - user: Node, - active_nodes: List, - delete_node_dict: List, - kept_nodes: List = None) -> None: + def _remove_deactive_node( + self, user_idx: int, user: Node, active_nodes: List, delete_node_dict: List, kept_nodes: List = None + ) -> None: """ remove deactivate nodes from active nodes """ @@ -169,7 +163,7 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None use_chunk = True if chunk_infos is not None else False chunk_within = False chunk_region_idx = None - chunk_ratio = 1 # use it to estimate chunk mem + chunk_ratio = 1 # use it to estimate chunk mem chunk_inputs_all = [] if use_chunk: @@ -184,7 +178,6 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos] for idx, node in enumerate(node_mgr.get_node_list()): - # if node in chunk start nodes, change chunk ratio and add chunk_tensor if use_chunk and idx in chunk_starts: chunk_within = True @@ -193,8 +186,9 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None # determine chunk ratio for current node if chunk_within: - chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx], - chunk_sizes[chunk_region_idx]) + chunk_ratio = self._get_chunk_ratio( + node, chunk_node_dim[chunk_region_idx], chunk_sizes[chunk_region_idx] + ) # add current node as active node self._add_active_node(node, active_nodes, chunk_ratio) @@ -222,7 +216,7 @@ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None # if node in chunk end nodes, restore chunk settings if use_chunk and idx in chunk_ends: - self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now + self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now chunk_within = False chunk_ratio = 1 chunk_region_idx = None diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index 59645c80e808..1c599049d9eb 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -8,7 +8,7 @@ from .select_chunk import SelectChunk from .trace_flow import TraceFlow from .trace_indice import TraceIndice -from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder +from .utils import NodeMgr, get_logger, is_non_compute_node, is_non_compute_node_except_placeholder class SearchChunk(object): @@ -121,8 +121,10 @@ def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_re # check if peak node already in chunk info if chunk_regions is not None: for i in chunk_regions: - if i["region"][0] < peak_region[0] <= i["region"][1] or \ - i["region"][0] < peak_region[1] <= i["region"][1]: + if ( + i["region"][0] < peak_region[0] <= i["region"][1] + or i["region"][0] < peak_region[1] <= i["region"][1] + ): return None active_node_num = [len(i) for i in active_node] @@ -146,9 +148,9 @@ def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_re region = i["region"] if chunk_region_start >= region[0] and chunk_region_end <= region[1]: return None - elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]): + elif region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]: chunk_region_start = region[1] + 1 - elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]): + elif region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]: chunk_region_end = region[0] - 1 return chunk_region_start, chunk_region_end @@ -171,7 +173,7 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis chunk_infos: possible regions found """ start_traces = input_trace[start_idx] - if len(start_traces) > 1: # TODO need to be removed + if len(start_traces) > 1: # TODO need to be removed return [] end_trace = output_trace[end_idx] end_node = self.node_mgr.get_node_by_idx(end_idx) @@ -180,8 +182,9 @@ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> Lis for end_dim, _ in enumerate(end_trace["indice"]): for start_node, start_trace in start_traces.items(): for start_dim, _ in enumerate(start_trace["indice"]): - if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, - end_idx): + if not self.trace_flow.check_region_start_end( + start_node, start_dim, start_idx, end_node, end_dim, end_idx + ): continue # flow search chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim) @@ -203,7 +206,7 @@ def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: N """ possible_chunk_region = [] output_trace = copy.deepcopy(self.trace_indice.indice_trace_list) - input_trace = [] # trace of a node's input nodes + input_trace = [] # trace of a node's input nodes for _, n in enumerate(self.node_mgr.get_node_list()): cur_trace = {} for arg in n.args: @@ -215,7 +218,8 @@ def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: N for end_idx in range(peak_region[1], max_chunk_region[1] + 1): # skip non compute nodes if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node( - self.node_mgr.get_node_by_idx(end_idx)): + self.node_mgr.get_node_by_idx(end_idx) + ): continue # select free dim chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx) @@ -279,15 +283,18 @@ def search_region(self) -> Dict: chunk_infos.append(chunk_info) mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem( - self.node_mgr.get_node_list(), chunk_infos) + self.node_mgr.get_node_list(), chunk_infos + ) if self.print_progress: - get_logger().info("AutoChunk find chunk region %d = (%d, %d)" % - (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1])) + get_logger().info( + "AutoChunk find chunk region %d = (%d, %d)" + % (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]) + ) if self.print_mem: self.print_mem = False - self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), - chunk_infos, - print_mem=True) + self.estimate_memory.estimate_chunk_inference_mem( + self.node_mgr.get_node_list(), chunk_infos, print_mem=True + ) return chunk_infos diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index 94a29bfd5691..8a60ba681f70 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -5,7 +5,6 @@ class SelectChunk(object): - def __init__( self, trace_indice: TraceIndice, @@ -20,7 +19,7 @@ def __init__( self.node_mgr = node_mgr if max_memory is not None: self.stratge = "fit_memory" - self.max_memory = max_memory # MB + self.max_memory = max_memory # MB else: self.stratge = "min_memory" @@ -57,16 +56,18 @@ def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, m cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_chunk_infos = chunk_infos + [cur_region] cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] - cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1] + cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1] cur_chunk_region_max_peak = max(cur_chunk_region_peak) if cur_chunk_region_max_peak < self.max_memory: - regions_dict.append({ - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - }) + regions_dict.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + } + ) # no region found if len(regions_dict) == 0: raise RuntimeError("Search failed. Try a larger memory threshold.") @@ -90,13 +91,15 @@ def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): chunk_size *= 2 reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_infos = chunk_infos + [reorder_chunk_info] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], - cur_chunk_infos)[0] - cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1]) + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( + chunk_region_dict["reorder_node_list"], cur_chunk_infos + )[0] + cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1]) # search exact size chunk_info = chunk_region_dict["chunk_info"] - chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict, - chunk_infos) + chunk_info["chunk_size"] = self._chunk_size_binary_search( + chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos + ) return chunk_info def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos): @@ -109,9 +112,10 @@ def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos) mid = int((left + right) / 2 + 0.5) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"], - cur_chunk_infos)[0] - cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1]) + cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( + chunk_region_dict["reorder_node_list"], cur_chunk_infos + )[0] + cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]) if cur_chunk_max_mem >= self.max_memory: right = mid - gap else: @@ -139,8 +143,10 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): return None # get max possible chunk region - max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]), - max([i["region"][1] for i in possible_chunk_regions])) + max_possible_chunk_region = ( + min([i["region"][0] for i in possible_chunk_regions]), + max([i["region"][1] for i in possible_chunk_regions]), + ) # get mem for chunk region regions_dict_list = [] @@ -149,15 +155,17 @@ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_chunk_infos = chunk_infos + [cur_region] cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] - cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] + cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1] cur_chunk_region_max_peak = max(cur_chunk_region_peak) - regions_dict_list.append({ - "chunk_info": region, - "chunk_max_mem": cur_chunk_region_max_peak, - "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), - "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list, - }) + regions_dict_list.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + } + ) # select the min mem chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list] @@ -175,7 +183,9 @@ def _is_legal_region(self, cur_chunk_info, chunk_infos): return False for i in chunk_infos: region = i["region"] - if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or - (chunk_region_start < region[0] and chunk_region_end < region[0])): + if not ( + (chunk_region_start > region[1] and chunk_region_end > region[1]) + or (chunk_region_start < region[0] and chunk_region_end < region[0]) + ): return False return True diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py index a1080fda1541..8b36c99bbadd 100644 --- a/colossalai/autochunk/trace_flow.py +++ b/colossalai/autochunk/trace_flow.py @@ -16,7 +16,6 @@ class TraceFlow(object): - def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None: self.trace_indice = trace_indice self.node_mgr = node_mgr @@ -151,7 +150,7 @@ def _assign_single_node_flow( return True def _get_all_node_info(self, end_dim, start_idx, end_idx): - cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node + cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}} while len(cur_node_list) > 0: @@ -266,7 +265,7 @@ def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, maybe_prepose_nodes.sort( key=lambda x: self.node_mgr.find_node_idx(x), reverse=True, - ) # from last node to first node + ) # from last node to first node prepose_nodes = [] # set every node as root, search its args, if all legal, turn root and args as prepose nodes while len(maybe_prepose_nodes) > 0: @@ -328,7 +327,8 @@ def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx): def flow_search(self, start_idx, start_dim, end_idx, end_dim): inputs, outputs = find_chunk_compute_input_and_output_nodes( - self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)) + self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1) + ) # get every node's chunk dim and fix dim all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) @@ -371,8 +371,9 @@ def flow_search(self, start_idx, start_dim, end_idx, end_dim): return chunk_info - def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, - chunk_info: Dict): + def _get_other_output_info( + self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict + ): start_node = self.node_mgr.get_node_by_idx(start_idx) # loop all outputs for output in outputs: @@ -384,8 +385,8 @@ def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: # skip non tensor if get_node_shape(output) is None: # log shape tensor - if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int): - chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out']) + if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int): + chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"]) continue # loop every dim of outputs, try to find a legal one for output_dim in range(len(get_node_shape(output))): @@ -421,7 +422,8 @@ def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: for k, v in new_all_node_info.items(): if k in chunk_info["node_chunk_dim"]: chunk_info["node_chunk_dim"][k]["fix_dim"] = list( - set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])) + set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]) + ) else: chunk_info["node_chunk_dim"][k] = v chunk_info["outputs"].append(output) @@ -443,8 +445,11 @@ def _reassign_reshape_size(self, chunk_info): if node.args[0] in chunk_info["inputs_non_chunk"]: continue reshape_args = flat_list(node.args[1:]) - if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len( - reshape_args[0].meta['fwd_out']) > 1: + if ( + len(reshape_args) == 1 + and get_node_shape(reshape_args[0]) is None + and len(reshape_args[0].meta["fwd_out"]) > 1 + ): continue chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"] new_shape = "" @@ -462,16 +467,17 @@ def _reassign_reshape_size(self, chunk_info): chunk_info["reshape_size"] = reshape_size return chunk_info - def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, - end_idx: int) -> bool: + def check_region_start_end( + self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int + ) -> bool: """ check if region start and end is legal """ # dim cannot be None - if (get_node_shape(end_node) is None or get_node_shape(start_node) is None): + if get_node_shape(end_node) is None or get_node_shape(start_node) is None: return False # dim size cannot be 1 - if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1): + if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1: return False # must have users if len(end_node.users) == 0: diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index fbe0741b8827..378c54acf782 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, List, Tuple +from typing import Dict, List from torch.fx.node import Node @@ -412,7 +412,7 @@ def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None: node_idx (int) """ # get conv input - assert node.kwargs['size'] is None + assert node.kwargs["size"] is None assert len(get_node_shape(node)) == 4 # assign index @@ -826,7 +826,7 @@ def _clear_trace(self, node_idx: int) -> None: # clear compute for dim_compute in trace["compute"]: for i in range(len(dim_compute) - 1, -1, -1): - if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes): + if dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes: dim_compute.pop(i) continue # clear source @@ -876,10 +876,24 @@ def trace_indice(self) -> None: self._assign_matmul_indice(node, idx) elif "softmax" == node_name: self._assign_softmax_indice(node, idx) - elif any(n == node_name for n in [ - "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp", - "sin", "cos" - ]): + elif any( + n == node_name + for n in [ + "mul", + "add", + "sigmoid", + "relu", + "sub", + "truediv", + "pow", + "dropout", + "where", + "tanh", + "exp", + "sin", + "cos", + ] + ): self._assign_elementwise_indice(node, idx) elif "einsum" == node_name: self._assign_einsum_indice(node, idx) @@ -920,7 +934,7 @@ def trace_indice(self) -> None: else: raise NotImplementedError(node_name, "module not implemented yet!") elif node.op == "get_attr": - self._assign_all_indice(node, idx) # get param + self._assign_all_indice(node, idx) # get param elif node.op == "output": continue else: diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py index 064baa047155..f6f803a5ce0a 100644 --- a/colossalai/autochunk/utils.py +++ b/colossalai/autochunk/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from typing import Any, Dict, List, Union from torch.fx.node import Node @@ -10,7 +10,6 @@ class NodeMgr(object): - def __init__(self, nodes_list: List[Node]) -> None: self._node_list = nodes_list self._node_dict = {} @@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, # we treat that input node as the input of the checkpoint function for node in nodes: for input_node in node._input_nodes.keys(): - if (input_node not in nodes and input_node not in input_nodes - and not is_non_compute_node_except_placeholder(input_node)): + if ( + input_node not in nodes + and input_node not in input_nodes + and not is_non_compute_node_except_placeholder(input_node) + ): input_nodes.append(input_node) # if a node has a user node which is not in the node list # we treat that user node as the node receiving the current node output for node in nodes: for output_node in node.users.keys(): - if (output_node not in nodes and node not in output_nodes - and not is_non_compute_node_except_placeholder_output(output_node)): + if ( + output_node not in nodes + and node not in output_nodes + and not is_non_compute_node_except_placeholder_output(output_node) + ): output_nodes.append(node) return input_nodes, output_nodes @@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]: for node in node_list: if get_node_shape(node) is not None: out.append(node) - elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance( - node.meta['fwd_out'][0], int): + elif ( + len(node.meta["fwd_out"]) > 0 + and isinstance(node.meta["fwd_out"], list) + and isinstance(node.meta["fwd_out"][0], int) + ): out.append(node) return out diff --git a/colossalai/booster/accelerator.py b/colossalai/booster/accelerator.py index fc2c4a40068b..92990907bc2e 100644 --- a/colossalai/booster/accelerator.py +++ b/colossalai/booster/accelerator.py @@ -1,12 +1,11 @@ import torch import torch.nn as nn -__all__ = ['Accelerator'] +__all__ = ["Accelerator"] _supported_devices = [ - 'cpu', - 'cuda', - + "cpu", + "cuda", # To be supported # 'xpu', # 'npu', @@ -25,21 +24,22 @@ class Accelerator: def __init__(self, device: str): self.device = device - assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}" + assert ( + self.device in _supported_devices + ), f"Device {self.device} is not supported yet, supported devices include {_supported_devices}" def bind(self): """ Set the default device for the current process. """ - if self.device == 'cpu': + if self.device == "cpu": pass - elif self.device == 'cuda': + elif self.device == "cuda": # TODO(FrankLeeeee): use global environment to check if it is a dist job # if is_distributed: # local_rank = EnvTable().get_local_rank() # torch.cuda.set_device(torch.device(f'cuda:{local_rank}')) - torch.cuda.set_device(torch.device('cuda')) - pass + torch.cuda.set_device(torch.device("cuda")) else: raise ValueError(f"Device {self.device} is not supported yet") diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index fb9dae7c9650..2aee72cbf2f1 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -16,7 +16,7 @@ from .plugin import Plugin from .plugin.pp_plugin_base import PipelinePluginBase -__all__ = ['Booster'] +__all__ = ["Booster"] class Booster: @@ -60,28 +60,31 @@ class Booster: plugin (Plugin): The plugin to run the training. Default: None. """ - def __init__(self, - device: Optional[str] = None, - mixed_precision: Optional[Union[MixedPrecision, str]] = None, - plugin: Optional[Plugin] = None) -> None: + def __init__( + self, + device: Optional[str] = None, + mixed_precision: Optional[Union[MixedPrecision, str]] = None, + plugin: Optional[Plugin] = None, + ) -> None: if plugin is not None: assert isinstance( - plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.' + plugin, Plugin + ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}." self.plugin = plugin # set accelerator if self.plugin and self.plugin.control_device(): self.accelerator = None if device is not None: - warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.') + warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.") else: - device = device or 'cuda' + device = device or "cuda" self.accelerator = Accelerator(device) # set precision if self.plugin and self.plugin.control_precision(): if mixed_precision is not None: - warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.') + warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.") self.mixed_precision = None elif mixed_precision is None: self.mixed_precision = None @@ -95,7 +98,7 @@ def __init__(self, self.mixed_precision = mixed_precision else: raise ValueError( - f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.' + f"Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}." ) if self.plugin is not None and self.plugin.control_checkpoint_io(): @@ -131,7 +134,8 @@ def boost( # transform model for mixed precision if self.plugin: model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure( - model, optimizer, criterion, dataloader, lr_scheduler) + model, optimizer, criterion, dataloader, lr_scheduler + ) if self.plugin and not self.plugin.control_device(): # transform model for accelerator @@ -154,13 +158,15 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None: # TODO(frank lee): implement this method with plugin optimizer.backward(loss) - def execute_pipeline(self, - data_iter: Iterator, - model: nn.Module, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Optional[Optimizer] = None, - return_loss: bool = True, - return_outputs: bool = False) -> Dict[str, Any]: + def execute_pipeline( + self, + data_iter: Iterator, + model: nn.Module, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[Optimizer] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> Dict[str, Any]: """ Execute forward & backward when utilizing pipeline parallel. Return loss or Huggingface style model outputs if needed. @@ -185,8 +191,9 @@ def execute_pipeline(self, ret_dict['loss'] is the loss of forward if return_loss is set to True, else None. ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None. """ - assert isinstance(self.plugin, - PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.' + assert isinstance( + self.plugin, PipelinePluginBase + ), f"The plugin {self.plugin.__class__.__name__} does not support pipeline." return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs) def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager: @@ -200,8 +207,10 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) - Returns: contextmanager: Context to disable gradient synchronization. """ - assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.' - assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' + assert ( + self.plugin is not None + ), f"no_sync is only enabled when a plugin is provided and the plugin supports no_sync." + assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync." return self.plugin.no_sync(model, optimizer) def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None: @@ -217,14 +226,16 @@ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, str """ self.checkpoint_io.load_model(model, checkpoint, strict) - def save_model(self, - model: Union[nn.Module, ModelWrapper], - checkpoint: str, - shard: bool = False, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - use_safetensors: bool = False) -> None: + def save_model( + self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: """Save model to checkpoint. Args: @@ -239,13 +250,15 @@ def save_model(self, size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved. """ - self.checkpoint_io.save_model(model, - checkpoint=checkpoint, - shard=shard, - gather_dtensor=gather_dtensor, - prefix=prefix, - size_per_shard=size_per_shard, - use_safetensors=use_safetensors) + self.checkpoint_io.save_model( + model, + checkpoint=checkpoint, + shard=shard, + gather_dtensor=gather_dtensor, + prefix=prefix, + size_per_shard=size_per_shard, + use_safetensors=use_safetensors, + ) def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None: """Load optimizer from checkpoint. @@ -260,13 +273,15 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None: """ self.checkpoint_io.load_optimizer(optimizer, checkpoint) - def save_optimizer(self, - optimizer: Optimizer, - checkpoint: str, - shard: bool = False, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024) -> None: + def save_optimizer( + self, + optimizer: Optimizer, + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ) -> None: """ Save optimizer to checkpoint. diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py index 0df9d84159f9..68c6221ec809 100644 --- a/colossalai/booster/mixed_precision/__init__.py +++ b/colossalai/booster/mixed_precision/__init__.py @@ -6,16 +6,22 @@ from .mixed_precision_base import MixedPrecision __all__ = [ - 'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision', - 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision', 'FP16NaiveMixedPrecision' + "MixedPrecision", + "mixed_precision_factory", + "FP16_Apex_MixedPrecision", + "FP16_Torch_MixedPrecision", + "FP32_MixedPrecision", + "BF16_MixedPrecision", + "FP8_MixedPrecision", + "FP16NaiveMixedPrecision", ] _mixed_precision_mapping = { - 'fp16': FP16TorchMixedPrecision, - 'fp16_apex': FP16ApexMixedPrecision, - 'fp16_naive': FP16NaiveMixedPrecision, - 'bf16': BF16MixedPrecision, - 'fp8': FP8MixedPrecision + "fp16": FP16TorchMixedPrecision, + "fp16_apex": FP16ApexMixedPrecision, + "fp16_naive": FP16NaiveMixedPrecision, + "bf16": BF16MixedPrecision, + "fp8": FP8MixedPrecision, } @@ -31,5 +37,5 @@ def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision: return _mixed_precision_mapping[mixed_precision_type]() else: raise ValueError( - f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}' + f"Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}" ) diff --git a/colossalai/booster/mixed_precision/fp16_apex.py b/colossalai/booster/mixed_precision/fp16_apex.py index e184271e932a..2fa7b54cdd30 100644 --- a/colossalai/booster/mixed_precision/fp16_apex.py +++ b/colossalai/booster/mixed_precision/fp16_apex.py @@ -23,16 +23,18 @@ class FP16ApexMixedPrecision(MixedPrecision): max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored. """ - def __init__(self, - opt_level: Optional[str] = "O1", - cast_model_type: torch.dtype = None, - patch_torch_functions: bool = None, - keep_batchnorm_fp32: Union[bool, str] = None, - master_weights: bool = None, - loss_scale: Union[float, str] = None, - cast_model_outputs: Any = None, - num_losses: Optional[int] = 1, - verbosity: int = 1, - min_loss_scale: float = None, - max_loss_scale: float = 2.**24) -> None: + def __init__( + self, + opt_level: Optional[str] = "O1", + cast_model_type: torch.dtype = None, + patch_torch_functions: bool = None, + keep_batchnorm_fp32: Union[bool, str] = None, + master_weights: bool = None, + loss_scale: Union[float, str] = None, + cast_model_outputs: Any = None, + num_losses: Optional[int] = 1, + verbosity: int = 1, + min_loss_scale: float = None, + max_loss_scale: float = 2.0**24, + ) -> None: pass diff --git a/colossalai/booster/mixed_precision/fp16_naive.py b/colossalai/booster/mixed_precision/fp16_naive.py index 5d0d815257f3..e5624a9d7477 100644 --- a/colossalai/booster/mixed_precision/fp16_naive.py +++ b/colossalai/booster/mixed_precision/fp16_naive.py @@ -15,12 +15,14 @@ class FP16NaiveMixedPrecision(MixedPrecision): verbose(bool): if set to `True`, will print debug info. """ - def __init__(self, - log_num_zeros_in_grad: bool, - initial_scale: int, - growth_factor: int, - backoff_factor: float, - hysteresis: int, - max_scale: int, - verbose: bool = None) -> None: + def __init__( + self, + log_num_zeros_in_grad: bool, + initial_scale: int, + growth_factor: int, + backoff_factor: float, + hysteresis: int, + max_scale: int, + verbose: bool = None, + ) -> None: pass diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py index 26fd92bd50b8..7dce6e6da33e 100644 --- a/colossalai/booster/mixed_precision/fp16_torch.py +++ b/colossalai/booster/mixed_precision/fp16_torch.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -9,7 +9,7 @@ from .mixed_precision_base import MixedPrecision -__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule'] +__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"] class TorchAMPOptimizer(OptimizerWrapper): @@ -29,17 +29,21 @@ class TorchAMPOptimizer(OptimizerWrapper): calls that may cause the scale to increase. Default: 2000. """ - def __init__(self, - optim: Optimizer, - init_scale: float = 2.**16, - growth_factor: float = 2.0, - backoff_factor: float = 0.5, - growth_interval: int = 2000) -> None: + def __init__( + self, + optim: Optimizer, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + ) -> None: super().__init__(optim) - self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval) + self.scaler = torch.cuda.amp.GradScaler( + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + ) def backward(self, loss: Tensor, *args, **kwargs) -> None: scaled_loss = self.scale_loss(loss) @@ -60,12 +64,14 @@ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: self.unscale_grad() super().clip_grad_by_value(clip_value, *args, **kwargs) - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> None: + def clip_grad_by_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = False, + *args, + **kwargs, + ) -> None: self.unscale_grad() super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs) @@ -102,22 +108,27 @@ class FP16TorchMixedPrecision(MixedPrecision): calls that may cause the scale to increase. Default: 2000. """ - def __init__(self, - init_scale: float = 2.**16, - growth_factor: float = 2.0, - backoff_factor: float = 0.5, - growth_interval: int = 2000) -> None: + def __init__( + self, + init_scale: float = 2.0**16, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + ) -> None: super().__init__() - self.torch_amp_kwargs = dict(init_scale=init_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval) - - def configure(self, - model: nn.Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - ) -> Tuple[nn.Module, OptimizerWrapper, Callable]: + self.torch_amp_kwargs = dict( + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + ) + + def configure( + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable]: model = TorchAMPModule(model) if optimizer is not None: optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs) diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py index f48bf38bd724..62f3708fc629 100644 --- a/colossalai/booster/plugin/__init__.py +++ b/colossalai/booster/plugin/__init__.py @@ -4,11 +4,12 @@ from .plugin_base import Plugin from .torch_ddp_plugin import TorchDDPPlugin -__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin', 'HybridParallelPlugin'] +__all__ = ["Plugin", "TorchDDPPlugin", "GeminiPlugin", "LowLevelZeroPlugin", "HybridParallelPlugin"] import torch from packaging import version -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): from .torch_fsdp_plugin import TorchFSDPPlugin - __all__.append('TorchFSDPPlugin') + + __all__.append("TorchFSDPPlugin") diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py index d5da5938bfd9..d2dd00453e32 100644 --- a/colossalai/booster/plugin/dp_plugin_base.py +++ b/colossalai/booster/plugin/dp_plugin_base.py @@ -10,25 +10,19 @@ class DPPluginBase(Plugin): - """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation. - """ + """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation.""" def __init__(self) -> None: super().__init__() - assert dist.is_initialized( - ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' + assert ( + dist.is_initialized() + ), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment" self.rank = dist.get_rank() self.world_size = dist.get_world_size() - def prepare_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): + def prepare_dataloader( + self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. @@ -60,11 +54,13 @@ def seed_worker(worker_id): torch.manual_seed(worker_seed) random.seed(worker_seed) - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index de03ba27bfda..83a00d4ee229 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -27,14 +27,13 @@ from .dp_plugin_base import DPPluginBase -__all__ = ['GeminiPlugin'] +__all__ = ["GeminiPlugin"] -SUPPORTED_PRECISION = ['fp16', 'bf16'] -PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16} +SUPPORTED_PRECISION = ["fp16", "bf16"] +PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} class GeminiCheckpointIO(GeneralCheckpointIO): - def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() @@ -74,13 +73,15 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): """ super().load_unsharded_optimizer(optimizer, checkpoint) - def save_sharded_model(self, - model: GeminiDDP, - checkpoint_path: str, - gather_dtensor: bool = False, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False): + def save_sharded_model( + self, + model: GeminiDDP, + checkpoint_path: str, + gather_dtensor: bool = False, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): """ Save sharded model. As there is communication when getting state dict, model.state_dict() must be called on all processes. @@ -97,34 +98,37 @@ def save_sharded_model(self, # Save shards of optimizer states. is_master = self.coordinator.is_master() - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=is_master, - use_safetensors=use_safetensors) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=is_master, + use_safetensors=use_safetensors, + ) # only save the index file on the master rank if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) save_config_file(model.module, checkpoint_path) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") - - def load_sharded_model(self, - model: GeminiDDP, - checkpoint_index_file: Path, - strict: bool = False, - use_safetensors: bool = False): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_sharded_model( + self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False + ): """ Load shard model, load model from multiple files. """ return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, - size_per_shard: int): + def save_sharded_optimizer( + self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + ): """ Save sharded optimizer state dict to checkpoint folder. As there is communication when getting state dict, this must be called on all processes. @@ -153,20 +157,24 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_ # Save shards of optimizer states. is_master = self.coordinator.is_master() - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=is_master, - use_safetensors=False) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=is_master, + use_safetensors=False, + ) # Wrap up index file. Only save it on master rank. if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info(f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str): """ @@ -185,8 +193,10 @@ def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Pa # Load param_groups. param_group_path = ckpt_index_file.get_param_group_filename() if param_group_path is None: - raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \ - Lacking param group file under current directory.') + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) saved_param_groups = torch.load(param_group_path) optimizer.load_param_groups(saved_param_groups) @@ -274,11 +284,11 @@ def __init__( chunk_config_dict: Optional[dict] = None, chunk_init_device: Optional[torch.device] = None, placement_policy: str = "static", - shard_param_frac: float = 1.0, # only for static placement - offload_optim_frac: float = 0.0, # only for static placement - offload_param_frac: float = 0.0, # only for static placement - warmup_non_model_data_ratio: float = 0.8, # only for auto placement - steady_cuda_cap_ratio: float = 0.9, # only for auto placement + shard_param_frac: float = 1.0, # only for static placement + offload_optim_frac: float = 0.0, # only for static placement + offload_param_frac: float = 0.0, # only for static placement + warmup_non_model_data_ratio: float = 0.8, # only for auto placement + steady_cuda_cap_ratio: float = 0.9, # only for auto placement precision: str = "fp16", pin_memory: bool = False, force_outputs_fp32: bool = False, @@ -300,7 +310,7 @@ def __init__( verbose: bool = False, ) -> None: super().__init__() - assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported' + assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" self.gemini_config = dict( chunk_config_dict=chunk_config_dict, chunk_init_device=(chunk_init_device or get_current_device()), @@ -319,16 +329,20 @@ def __init__( memstats=memstats, mixed_precision=PRECISION_STR_TO_DTYPE[precision], ) - self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,) - self.optim_kwargs = dict(initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - max_norm=max_norm, - norm_type=norm_type) + self.zero_optim_config = dict( + gpu_margin_mem_ratio=gpu_margin_mem_ratio, + ) + self.optim_kwargs = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type, + ) self.verbose = verbose def support_no_sync(self) -> bool: @@ -344,7 +358,7 @@ def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] + return ["cuda"] def configure( self, @@ -354,7 +368,6 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - if not isinstance(model, ModelWrapper): # convert model to sync bn # FIXME(ver217): gemini does not support sync bn @@ -368,13 +381,10 @@ def configure( # wrap the model with Gemini model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose) - if optimizer is not None and \ - not isinstance(optimizer, OptimizerWrapper): - optimizer = GeminiOptimizer(optimizer, - model.unwrap(), - **self.zero_optim_config, - **self.optim_kwargs, - verbose=self.verbose) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + optimizer = GeminiOptimizer( + optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose + ) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index d15245523226..c1693fa8d3a1 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -37,10 +37,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): - - def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict, custom_policy: Policy) -> None: - + def __init__( + self, + module: Module, + precision: str, + shard_config: ShardConfig, + dp_group: ProcessGroup, + use_ddp: bool, + ddp_config: dict, + custom_policy: Policy, + ) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group @@ -54,13 +60,14 @@ def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp for shared_param in self.shared_params: if len(shared_param) > 0: self.shared_param_process_groups.append( - self.stage_manager.init_process_group_by_stages(list(shared_param.keys()))) + self.stage_manager.init_process_group_by_stages(list(shared_param.keys())) + ) # setting mixed_precision self.mixed_precision = None - if precision == 'fp16': + if precision == "fp16": self.mixed_precision = torch.float16 - elif precision == 'bf16': + elif precision == "bf16": self.mixed_precision = torch.bfloat16 if self.mixed_precision is not None: module = module.to(self.mixed_precision) @@ -123,22 +130,21 @@ def get_param_info(optim: Optimizer): if optim is None: return {} - param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} + param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} start_index = 0 for group in optim.param_groups: + packed_group = {k: v for k, v in group.items() if k != "params"} + packed_group["params"] = [] - packed_group = {k: v for k, v in group.items() if k != 'params'} - packed_group['params'] = [] - - for param_id, param in enumerate(group['params'], start_index): + for param_id, param in enumerate(group["params"], start_index): original_shape = param.shape if isinstance(param, torch.Tensor) else None - packed_group['params'].append(param_id) - param_info['param2id'][id(param)] = param_id - param_info['id2param'][param_id] = id(param) - param_info['param2shape'][id(param)] = original_shape + packed_group["params"].append(param_id) + param_info["param2id"][id(param)] = param_id + param_info["id2param"][param_id] = id(param) + param_info["param2shape"][id(param)] = original_shape - param_info['param_groups'].append(packed_group) - start_index += len(group['params']) + param_info["param_groups"].append(packed_group) + start_index += len(group["params"]) return param_info @@ -147,13 +153,12 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): model_params = set(model.parameters()) new_param_groups = [] for group in optim.param_groups: - params = [p for p in group['params'] if p in model_params] - new_param_groups.append({**group, 'params': params}) - optim.__setstate__({'param_groups': new_param_groups}) + params = [p for p in group["params"] if p in model_params] + new_param_groups.append({**group, "params": params}) + optim.__setstate__({"param_groups": new_param_groups}) class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): self.param_info = param_info if use_pipeline: @@ -162,60 +167,87 @@ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_in class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): - - def __init__(self, - optim: Optimizer, - model: Module, - use_pipeline: bool, - param_info: OrderedDict, - precision: str = 'fp16', - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0): + def __init__( + self, + optim: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + precision: str = "fp16", + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + ): self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) - super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, - hysteresis, max_scale, max_norm) + super().__init__( + optim, + precision, + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, + max_norm, + ) class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): - def __init__( - self, - optimizer: Optimizer, - model: Module, - use_pipeline: bool, - param_info: OrderedDict, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2., - backoff_factor: float = .5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp - forced_dtype: Optional[torch.dtype] = None): + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None, + ): self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optimizer, model) - super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, - hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype, - overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group, - forced_dtype) + super().__init__( + optimizer, + initial_scale, + min_scale, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, + max_scale, + clip_grad_norm, + verbose, + reduce_bucket_size, + communication_dtype, + overlap_communication, + partition_grad, + cpu_offload, + dp_process_group, + tp_process_group, + forced_dtype, + ) class HybridParallelPlugin(PipelinePluginBase): @@ -276,46 +308,47 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. """ - def __init__(self, - tp_size: int, - pp_size: int, - precision: str = 'fp16', - zero_stage: int = 0, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_jit_fused: bool = False, - enable_sequence_parallelism: bool = False, - enable_sequence_overlap: bool = False, - num_microbatches: Optional[int] = None, - microbatch_size: Optional[int] = None, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - broadcast_buffers: bool = True, - ddp_bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False, - zero_bucket_size_in_m: int = 12, - cpu_offload: bool = False, - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - custom_policy: Policy = None) -> None: - + def __init__( + self, + tp_size: int, + pp_size: int, + precision: str = "fp16", + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + enable_sequence_overlap: bool = False, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + custom_policy: Policy = None, + ) -> None: super().__init__() - assert dist.get_world_size() % ( - tp_size * pp_size - ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" if enable_sequence_parallelism: - assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' + assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" self.tp_size = tp_size self.pp_size = pp_size @@ -334,24 +367,28 @@ def __init__(self, self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: - assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' - assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' + assert ( + num_microbatches is not None or microbatch_size is not None + ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" + assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism" self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) - self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, - num_microbatches=num_microbatches, - microbatch_size=microbatch_size) + self.schedule = OneForwardOneBackwardSchedule( + self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size + ) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) - self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, - pipeline_stage_manager=self.stage_manager, - enable_tensor_parallelism=self.tp_size > 1, - enable_all_optimization=self.enable_all_optimization, - enable_fused_normalization=self.enable_fused_normalization, - enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism, - enable_sequence_overlap=enable_sequence_overlap) + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + enable_sequence_overlap=enable_sequence_overlap, + ) self.amp_config = dict( initial_scale=initial_scale, growth_factor=growth_factor, @@ -362,18 +399,22 @@ def __init__(self, max_scale=max_scale, ) - self.ddp_config = dict(broadcast_buffers=broadcast_buffers, - bucket_cap_mb=ddp_bucket_cap_mb, - find_unused_parameters=find_unused_parameters, - check_reduction=check_reduction, - gradient_as_bucket_view=gradient_as_bucket_view, - static_graph=static_graph) + self.ddp_config = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) - self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - cpu_offload=cpu_offload, - partition_grad=(self.zero_stage == 2)) + self.zero_config = dict( + reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2), + ) self.max_norm = max_norm @@ -382,10 +423,10 @@ def enable_pipeline_parallelism(self) -> bool: return self.pp_size > 1 def supported_devices(self) -> List[str]: - return ['cuda'] + return ["cuda"] def supported_precisions(self) -> List[str]: - return ['fp16', 'bf16', 'fp32'] + return ["fp16", "bf16", "fp32"] def control_device(self) -> bool: return True @@ -410,57 +451,67 @@ def configure( param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 - model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, - self.ddp_config, self.custom_policy) + model = HybridParallelModule( + model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy + ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: - if self.precision in ['fp16', 'bf16']: - optimizer = HybridParallelAMPOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - precision=self.precision, - max_norm=self.max_norm, - **self.amp_config) - self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, - optimizer.master_to_working_map) + if self.precision in ["fp16", "bf16"]: + optimizer = HybridParallelAMPOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config, + ) + self.checkpoint_io.link_master_and_working_param( + optimizer.working_to_master_map, optimizer.master_to_working_map + ) else: - optimizer = HybridParallelNaiveOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info) + optimizer = HybridParallelNaiveOptimizer( + optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info + ) else: assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." - assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." - optimizer = HybridParallelZeroOptimizer(optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - dp_process_group=self.dp_group, - tp_process_group=self.tp_group, - verbose=True, - clip_grad_norm=self.max_norm, - **self.zero_config, - **self.amp_config) - self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, - optimizer._param_store.master_to_working_param) + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." + optimizer = HybridParallelZeroOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=self.dp_group, + tp_process_group=self.tp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.zero_config, + **self.amp_config, + ) + self.checkpoint_io.link_master_and_working_param( + optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param + ) return model, optimizer, criterion, dataloader, lr_scheduler - def execute_pipeline(self, - data_iter: Iterator, - model: HybridParallelModule, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Optional[Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, - HybridParallelZeroOptimizer]] = None, - return_loss: bool = True, - return_outputs: bool = False) -> dict: - assert self.enable_pipeline_parallelism, 'pipeline parallelism is not enabled' + def execute_pipeline( + self, + data_iter: Iterator, + model: HybridParallelModule, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[ + Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer] + ] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> dict: + assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" # return loss or outputs if needed ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() with ctx: - outputs = self.schedule.forward_backward_step(model, data_iter, criterion, optimizer, return_loss, - return_outputs) + outputs = self.schedule.forward_backward_step( + model, data_iter, criterion, optimizer, return_loss, return_outputs + ) model.sync_shared_params() if isinstance(optimizer, HybridParallelZeroOptimizer): optimizer.sync_grad() @@ -468,15 +519,9 @@ def execute_pipeline(self, model.sync_grads() return outputs - def prepare_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): + def prepare_dataloader( + self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + ): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. @@ -499,10 +544,9 @@ def prepare_dataloader(self, :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. """ _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, - num_replicas=self.pg_mesh.size(DP_AXIS), - rank=self.pg_mesh.coordinate(DP_AXIS), - shuffle=shuffle) + sampler = DistributedSampler( + dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle + ) # Deterministic dataloader def seed_worker(worker_id): @@ -511,14 +555,16 @@ def seed_worker(worker_id): torch.manual_seed(worker_seed) random.seed(worker_seed) - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) def get_checkpoint_io(self) -> CheckpointIO: self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 9adb4beec9b9..86adee7fe226 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,14 +1,12 @@ import logging import os -import warnings from functools import partial from pathlib import Path from types import MethodType -from typing import Callable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple import torch import torch.nn as nn -from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils._pytree import tree_map @@ -33,7 +31,7 @@ from .dp_plugin_base import DPPluginBase from .torch_ddp_plugin import TorchDDPCheckpointIO -__all__ = ['LowLevelZeroPlugin'] +__all__ = ["LowLevelZeroPlugin"] def _convert_floating_point(x, dtype: torch.dtype = torch.float16): @@ -42,17 +40,16 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): return x -SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32'] +SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"] class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str) -> None: super().__init__(module) self.dtype = None - if precision == 'fp16': + if precision == "fp16": self.dtype = torch.float16 - elif precision == 'bf16': + elif precision == "bf16": self.dtype = torch.bfloat16 if self.dtype is not None: module = module.to(self.dtype) @@ -74,7 +71,6 @@ def unwrap(self): class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): """Save optimizer to checkpoint but only on master process. @@ -91,12 +87,14 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors=False) - def save_sharded_optimizer(self, - optimizer: OptimizerWrapper, - checkpoint: str, - gather_dtensor: bool = False, - prefix: str = None, - size_per_shard: int = 1024): + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = False, + prefix: str = None, + size_per_shard: int = 1024, + ): """ Save sharded Zero-optimizer checkpoint under the given checkpointing path. The following files will be created under the path: @@ -148,9 +146,11 @@ def save_sharded_optimizer(self, index_file.append_meta_data("total_size", total_size) if self.coordinator.is_master(): index_file.write_index_file(save_index_file) - logging.info(f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): """Load sharded optimizer with the given path to index file. @@ -170,8 +170,10 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s # Load param_groups param_group_path = ckpt_index_file.get_param_group_filename() if param_group_path is None: - raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \ - Lacking param group file under current directory.') + raise RuntimeError( + f"Invalid index file path {index_file_path} for an optimizer. \ + Lacking param group file under current directory." + ) id_map = load_param_groups_into_optimizer(optimizer, param_group_path) checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() @@ -181,9 +183,10 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s # shard state dict for param_idx, state in state_dict.items(): for k, v in state.items(): - if isinstance(v, torch.Tensor) and k != 'step': - padding_size = (self.coordinator.world_size - - v.numel() % self.coordinator.world_size) % self.coordinator.world_size + if isinstance(v, torch.Tensor) and k != "step": + padding_size = ( + self.coordinator.world_size - v.numel() % self.coordinator.world_size + ) % self.coordinator.world_size with torch.no_grad(): v = v.flatten() if padding_size > 0: @@ -194,33 +197,39 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s sharded_optimizer_loading_epilogue(optimizer) - def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, - use_safetensors: bool): + def save_unsharded_model( + self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool + ): assert isinstance(model, LowLevelZeroModel) super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors) - def save_sharded_model(self, - model: nn.Module, - checkpoint_path: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False): + def save_sharded_model( + self, + model: nn.Module, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): assert isinstance(model, LowLevelZeroModel) - super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, - use_safetensors) + super().save_sharded_model( + model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + ) def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True): assert isinstance(model, LowLevelZeroModel) super().load_unsharded_model(model.module, checkpoint, strict) model.update_master_params() - def load_sharded_model(self, - model: LowLevelZeroModel, - checkpoint_index_file: Path, - strict: bool = False, - use_safetensors: bool = False, - load_sub_module: bool = True): + def load_sharded_model( + self, + model: LowLevelZeroModel, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): assert isinstance(model, LowLevelZeroModel) super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module) model.update_master_params() @@ -264,7 +273,7 @@ class LowLevelZeroPlugin(DPPluginBase): def __init__( self, stage: int = 1, - precision: str = 'fp16', + precision: str = "fp16", initial_scale: float = 2**32, min_scale: float = 1, growth_factor: float = 2, @@ -281,9 +290,9 @@ def __init__( verbose: bool = False, ) -> None: super().__init__() - assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' - assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training' - assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now' + assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" + assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training" + assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now" self.stage = stage self.precision = precision self.zero_optim_kwargs = dict( @@ -319,7 +328,7 @@ def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] + return ["cuda"] def configure( self, @@ -329,15 +338,13 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - if not isinstance(model, ModelWrapper): model = LowLevelZeroModel(model, self.precision) - if optimizer is not None and \ - not isinstance(optimizer, OptimizerWrapper): - optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer, - **self.zero_optim_kwargs, - verbose=self.verbose) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer( + optimizer, **self.zero_optim_kwargs, verbose=self.verbose + ) # inject update_master_params model.update_master_params = MethodType(optimizer.update_master_params, model) diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py index fb21e57f41f7..4e570cbe8abc 100644 --- a/colossalai/booster/plugin/plugin_base.py +++ b/colossalai/booster/plugin/plugin_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple import torch.nn as nn from torch.optim import Optimizer @@ -9,11 +9,10 @@ from colossalai.checkpoint_io import CheckpointIO from colossalai.interface import OptimizerWrapper -__all__ = ['Plugin'] +__all__ = ["Plugin"] class Plugin(ABC): - @abstractmethod def supported_devices(self) -> List[str]: pass @@ -51,33 +50,31 @@ def control_checkpoint_io(self) -> bool: """ Whether the plugin controls the checkpoint io """ - pass @abstractmethod def get_checkpoint_io(self) -> CheckpointIO: """ Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True. """ - pass @abstractmethod def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: """ Context manager to disable gradient synchronization. """ - pass @abstractmethod - def prepare_dataloader(self, - dataset: Dataset, - batch_size: int, - shuffle: bool = False, - seed: int = 1024, - drop_last: bool = False, - pin_memory: bool = False, - num_workers: int = 0, - **kwargs): + def prepare_dataloader( + self, + dataset: Dataset, + batch_size: int, + shuffle: bool = False, + seed: int = 1024, + drop_last: bool = False, + pin_memory: bool = False, + num_workers: int = 0, + **kwargs, + ): """Prepare a dataloader for distributed training. The dataloader will be wrapped by `torch.utils.data.DataLoader` """ - pass diff --git a/colossalai/booster/plugin/pp_plugin_base.py b/colossalai/booster/plugin/pp_plugin_base.py index f52844db082f..3d91eb95b409 100644 --- a/colossalai/booster/plugin/pp_plugin_base.py +++ b/colossalai/booster/plugin/pp_plugin_base.py @@ -9,13 +9,14 @@ class PipelinePluginBase(Plugin): - @abstractmethod - def execute_pipeline(self, - data_iter: Iterator, - model: ModelWrapper, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Optional[OptimizerWrapper] = None, - return_loss: bool = True, - return_outputs: bool = False) -> dict: + def execute_pipeline( + self, + data_iter: Iterator, + model: ModelWrapper, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> dict: pass diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index f3f779c88e42..30d34e7dd5e5 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP @@ -12,11 +12,10 @@ from .dp_plugin_base import DPPluginBase -__all__ = ['TorchDDPPlugin'] +__all__ = ["TorchDDPPlugin"] class TorchDDPCheckpointIO(GeneralCheckpointIO): - def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() @@ -49,25 +48,29 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) - def save_sharded_model(self, - model: nn.Module, - checkpoint_path: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False): + def save_sharded_model( + self, + model: nn.Module, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): """ Save model to checkpoint but only on master process. """ if self.coordinator.is_master(): super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors) - def save_sharded_optimizer(self, - optimizer: Optimizer, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024): + def save_sharded_optimizer( + self, + optimizer: Optimizer, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): """ Save optimizer to checkpoint but only on master process. """ @@ -76,7 +79,6 @@ def save_sharded_optimizer(self, class TorchDDPModel(ModelWrapper): - def __init__(self, module: nn.Module, *args, **kwargs) -> None: super().__init__(module) self.module = DDP(module, *args, **kwargs) @@ -109,20 +111,24 @@ class TorchDDPPlugin(DPPluginBase): static_graph (bool, optional): Whether to use static graph. Defaults to False. """ - def __init__(self, - broadcast_buffers: bool = True, - bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False) -> None: + def __init__( + self, + broadcast_buffers: bool = True, + bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + ) -> None: super().__init__() - self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers, - bucket_cap_mb=bucket_cap_mb, - find_unused_parameters=find_unused_parameters, - check_reduction=check_reduction, - gradient_as_bucket_view=gradient_as_bucket_view, - static_graph=static_graph) + self.ddp_kwargs = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) def support_no_sync(self) -> bool: return True @@ -131,13 +137,13 @@ def control_precision(self) -> bool: return False def supported_precisions(self) -> List[str]: - return ['fp16', 'fp16_apex', 'bf16', 'fp8'] + return ["fp16", "fp16_apex", "bf16", "fp8"] def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] + return ["cuda"] def configure( self, @@ -156,8 +162,7 @@ def configure( # wrap the model with PyTorch DDP model = TorchDDPModel(model, **self.ddp_kwargs) - if optimizer is not None and \ - not isinstance(optimizer, OptimizerWrapper): + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = OptimizerWrapper(optimizer) return model, optimizer, criterion, dataloader, lr_scheduler @@ -169,5 +174,5 @@ def get_checkpoint_io(self) -> CheckpointIO: return TorchDDPCheckpointIO() def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.' + assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin." return model.module.no_sync() diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index fb7b5baadd0c..d12b784b4fc1 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,13 +1,13 @@ import warnings from pathlib import Path -from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Iterable, Iterator, List, Optional, Tuple import torch import torch.nn as nn from packaging import version from torch.distributed import ProcessGroup -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): from torch.distributed.fsdp import FullStateDictConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType @@ -31,11 +31,10 @@ from .dp_plugin_base import DPPluginBase -__all__ = ['TorchFSDPPlugin'] +__all__ = ["TorchFSDPPlugin"] class TorchFSDPCheckpointIO(GeneralCheckpointIO): - def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() @@ -69,26 +68,36 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], - size_per_shard: int, use_safetensors: bool): + def save_sharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool, + prefix: Optional[str], + size_per_shard: int, + use_safetensors: bool, + ): """ Save model to checkpoint but only on master process. """ raise NotImplementedError("Sharded model checkpoint is not supported yet.") - def load_sharded_model(self, - model: nn.Module, - checkpoint_index_file: Path, - strict: bool = False, - use_safetensors: bool = False, - load_sub_module: bool = True): + def load_sharded_model( + self, + model: nn.Module, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): """ Load model to checkpoint but only on master process. """ raise NotImplementedError("Sharded model checkpoint is not supported yet.") - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, - size_per_shard: int): + def save_sharded_optimizer( + self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int + ): """ Save optimizer to checkpoint but only on master process. """ @@ -109,7 +118,6 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): class TorchFSDPModel(ModelWrapper): - def __init__(self, module: nn.Module, *args, **kwargs) -> None: super().__init__(module) self.module = FSDP(module, *args, **kwargs) @@ -119,7 +127,6 @@ def unwrap(self): class FSDPOptimizerWrapper(OptimizerWrapper): - def __init__(self, optimizer: Optimizer, model: nn.Module): self.model = model super().__init__(optimizer) @@ -147,7 +154,7 @@ class TorchFSDPPlugin(DPPluginBase): See https://pytorch.org/docs/stable/fsdp.html for details. """ - if version.parse(torch.__version__) >= version.parse('1.12.0'): + if version.parse(torch.__version__) >= version.parse("1.12.0"): def __init__( self, @@ -162,15 +169,18 @@ def __init__( sync_module_states: bool = False, ): super().__init__() - self.fsdp_kwargs = dict(process_group=process_group, - sharding_strategy=sharding_strategy, - cpu_offload=cpu_offload, - auto_wrap_policy=auto_wrap_policy, - backward_prefetch=backward_prefetch, - mixed_precision=mixed_precision, - ignored_modules=ignored_modules, - param_init_fn=param_init_fn, - sync_module_states=sync_module_states) + self.fsdp_kwargs = dict( + process_group=process_group, + sharding_strategy=sharding_strategy, + cpu_offload=cpu_offload, + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=backward_prefetch, + mixed_precision=mixed_precision, + ignored_modules=ignored_modules, + param_init_fn=param_init_fn, + sync_module_states=sync_module_states, + ) + else: raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") @@ -184,13 +194,13 @@ def control_precision(self) -> bool: return True def supported_precisions(self) -> List[str]: - return ['fp16', 'bf16'] + return ["fp16", "bf16"] def control_device(self) -> bool: return True def supported_devices(self) -> List[str]: - return ['cuda'] + return ["cuda"] def configure( self, @@ -200,14 +210,13 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - # wrap the model with PyTorch FSDP fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs) if optimizer is not None: if len(optimizer.param_groups) > 1: warnings.warn( - 'TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used.' + "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used." ) optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index e1aa6543ef39..19b61730bded 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -3,4 +3,4 @@ from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO from .index_file import CheckpointIndexFile -__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO'] +__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"] diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index baff24e1cb25..f8ce8f4e5210 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -11,7 +11,7 @@ from .utils import has_index_file -__all__ = ['CheckpointIO'] +__all__ = ["CheckpointIO"] class CheckpointIO(ABC): @@ -61,10 +61,9 @@ class CheckpointIO(ABC): # ====================================== # Public methods # ====================================== - def load_model(self, - model: Union[nn.Module, ModelWrapper], - checkpoint: str, - strict: bool = True) -> Union[nn.Module, ModelWrapper]: + def load_model( + self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True + ) -> Union[nn.Module, ModelWrapper]: """ Load model from checkpoint. @@ -98,14 +97,16 @@ def load_model(self, return origin_model - def save_model(self, - model: Union[nn.Module, ModelWrapper], - checkpoint: str, - shard: bool = False, - gather_dtensor: bool = True, - prefix: str = None, - size_per_shard: int = 1024, - use_safetensors: bool = False): + def save_model( + self, + model: Union[nn.Module, ModelWrapper], + checkpoint: str, + shard: bool = False, + gather_dtensor: bool = True, + prefix: str = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ): """ Save model to checkpoint. @@ -157,7 +158,7 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No if Path(checkpoint).is_dir() and not index_file_exists: # if the checkpoint is a directory and there is no index file, raise error - raise ValueError(f'Cannot find index file in {checkpoint}') + raise ValueError(f"Cannot find index file in {checkpoint}") if index_file_exists: # the existence of index file means it is a sharded checkpoint @@ -165,13 +166,15 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No else: self.load_unsharded_optimizer(optimizer, checkpoint) - def save_optimizer(self, - optimizer: Optimizer, - checkpoint: str, - shard: bool = False, - gather_dtensor=True, - prefix: str = None, - size_per_shard: int = 1024): + def save_optimizer( + self, + optimizer: Optimizer, + checkpoint: str, + shard: bool = False, + gather_dtensor=True, + prefix: str = None, + size_per_shard: int = 1024, + ): """ Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors. @@ -207,7 +210,6 @@ def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: boo strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. """ - pass @abstractmethod def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): @@ -220,11 +222,17 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. """ - pass @abstractmethod - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str], - size_per_shard: int, use_safetensors: bool): + def save_sharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool, + prefix: Optional[str], + size_per_shard: int, + use_safetensors: bool, + ): """ Save model to sharded checkpoint. @@ -236,7 +244,6 @@ def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: size_per_shard (int): size per shard in MB. use_safetensors (bool): whether to use safe tensors. """ - pass @abstractmethod def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): @@ -249,7 +256,6 @@ def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor gather_dtensor (bool): whether to gather the distributed tensor to the first device. use_safetensors (bool): whether to use safe tensors. """ - pass # ======================================================== # Abstract methods for optimizer loading/saving implementation @@ -265,7 +271,6 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. prefix (str): prefix for the optimizer checkpoint. """ - pass @abstractmethod def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): @@ -276,11 +281,11 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): optimizer (Optimizer): optimizer to be loaded. checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. """ - pass @abstractmethod - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, - size_per_shard: int): + def save_sharded_optimizer( + self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + ): """ Save optimizer to sharded checkpoint. @@ -291,7 +296,6 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_ prefix (str): prefix for the optimizer checkpoint. size_per_shard (int): size per shard in MB. """ - pass @abstractmethod def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): @@ -303,7 +307,6 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gathe checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. gather_dtensor (bool): whether to gather the distributed tensor to the first device. """ - pass # ============================================ # methods for loading and saving lr scheduler diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index faaf1d22722a..b0e593e90d8c 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -3,9 +3,8 @@ import os from functools import reduce from pathlib import Path -from typing import Iterator, Optional, OrderedDict, Tuple +from typing import Optional -import torch.distributed as dist import torch.nn as nn from torch.optim import Optimizer @@ -16,7 +15,6 @@ from .utils import ( get_model_base_filenames, get_optimizer_base_filenames, - get_shard_filename, is_safetensors_available, load_param_groups_into_optimizer, load_shard_state_dict, @@ -33,7 +31,7 @@ unwrap_optimizer, ) -__all__ = ['GeneralCheckpointIO'] +__all__ = ["GeneralCheckpointIO"] class GeneralCheckpointIO(CheckpointIO): @@ -70,8 +68,10 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre # Load param_groups param_group_path = ckpt_index_file.get_param_group_filename() if param_group_path is None: - raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \ - Lacking param group file under current directory.') + raise RuntimeError( + f"Invalid index file path {index_file_path} for an optimizer. \ + Lacking param group file under current directory." + ) id_map = load_param_groups_into_optimizer(optimizer, param_group_path) checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() @@ -123,19 +123,23 @@ def save_sharded_optimizer( # Save shards of optimizer states. # In general cases, is_master is set to True to get the right behavior. - total_size = save_state_dict_shards(sharded_state_dict=sharded_state, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=True, - use_safetensors=False) + total_size = save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=True, + use_safetensors=False, + ) # Wrap up index file. index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info(f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): checkpoint = load_state_dict(checkpoint) @@ -150,13 +154,15 @@ def save_unsharded_optimizer( # TODO(FrankLeeeee): handle distributed tensors save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) - def save_sharded_model(self, - model: nn.Module, - checkpoint_path: str, - gather_dtensor: bool = False, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False): + def save_sharded_model( + self, + model: nn.Module, + checkpoint_path: str, + gather_dtensor: bool = False, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): """ implement this method as it can be supported by Huggingface model, save shard model, save model to multiple files @@ -175,26 +181,32 @@ def save_sharded_model(self, # Save shards of optimizer states. # In general cases, is_master is set to True to get the right behavior. - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=True, - use_safetensors=use_safetensors) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=True, + use_safetensors=use_safetensors, + ) index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) save_config_file(model, checkpoint_path, is_master=True) - logging.info(f"The model is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") - - def load_sharded_model(self, - model: nn.Module, - checkpoint_index_file: Path, - strict: bool = False, - use_safetensors: bool = False, - load_sub_module: bool = True): + logging.info( + f"The model is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_sharded_model( + self, + model: nn.Module, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): """ load shard model, load model from multiple files """ @@ -219,7 +231,11 @@ def load_sharded_model(self, if strict: remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) if len(remain_keys) > 0: - error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join( - '"{}"'.format(k) for k in missing_keys)) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) + error_msgs = "Missing key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in missing_keys) + ) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + self.__class__.__name__, "\n\t".join(error_msgs) + ) + ) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 270fd8564754..18c59a880dd6 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -1,5 +1,4 @@ import copy -import gc import logging import os from pathlib import Path @@ -35,9 +34,9 @@ ) try: - from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" class HybridParallelCheckpointIO(GeneralCheckpointIO): @@ -52,12 +51,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True. """ - def __init__(self, - dp_group: ProcessGroup, - pp_group: ProcessGroup, - tp_group: ProcessGroup, - zero_stage: int, - verbose: bool = True) -> None: + def __init__( + self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True, + ) -> None: super().__init__() self.dp_group = dp_group self.pp_group = pp_group @@ -68,17 +69,16 @@ def __init__(self, self.dp_size = dist.get_world_size(dp_group) self.pp_size = dist.get_world_size(pp_group) self.tp_size = dist.get_world_size(tp_group) - self.use_zero = (zero_stage > 0) + self.use_zero = zero_stage > 0 self.verbose = verbose self.working_to_master_map = None self.master_to_working_map = None self.coordinator = DistCoordinator() @staticmethod - def _model_sharder(model: nn.Module, - prefix: str = '', - keep_vars: bool = False, - size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: + def _model_sharder( + model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024 + ) -> Iterator[Tuple[OrderedDict, int]]: # An internel method that breaks state_dict of model into shards within limited size. state_dict_sharder = StateDictSharder(size_per_shard) @@ -103,8 +103,10 @@ def _model_sharder(model: nn.Module, # Save extra states. extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(model.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): extra_state = model.get_extra_state() block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) if block is not None: @@ -114,20 +116,20 @@ def _model_sharder(model: nn.Module, yield state_dict_sharder.current_block, state_dict_sharder.current_block_size @staticmethod - def _optimizer_sharder(optimizer: OptimizerWrapper, - use_zero: bool, - dp_group: ProcessGroup, - tp_group: ProcessGroup, - master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, - size_per_shard: int = 1024): - + def _optimizer_sharder( + optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, + size_per_shard: int = 1024, + ): # An internel method that breaks state_dict of optimizer into shards within limited size. state_dict_sharder = StateDictSharder(size_per_shard) param_info = optimizer.param_info for param, state in optimizer.optim.state.items(): - if param is None: continue @@ -136,15 +138,17 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, else: working_param = param - param_id = param_info['param2id'][id(working_param)] - original_shape = param_info['param2shape'][id(working_param)] - state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state, - working_param, - original_shape=original_shape, - dp_group=dp_group, - tp_group=tp_group, - use_zero=use_zero, - inplace=False) + param_id = param_info["param2id"][id(working_param)] + original_shape = param_info["param2shape"][id(working_param)] + state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False, + ) block, block_size = state_dict_sharder.append_optim_state(param_id, state_) if block is not None: @@ -153,13 +157,15 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, # Return the last block in sharder. yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - def save_sharded_model(self, - model: nn.Module, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - use_safetensors: bool = False) -> None: + def save_sharded_model( + self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: """ Save sharded model checkpoint under the given checkpointing path. The following files will be created under the path: @@ -194,24 +200,28 @@ def save_sharded_model(self, state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) - control_saving = (self.tp_rank == 0) + control_saving = self.tp_rank == 0 if self.pp_size == 1: # When pipeline is not used, save the model shards as in general checkpointIO - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) if control_saving: index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) save_config_file(model, checkpoint) if self.verbose: - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) else: # When pipeline is used, each stage produces its own shard files and index files. @@ -228,15 +238,19 @@ def save_sharded_model(self, save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors, - use_pp_format=True) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + use_pp_format=True, + ) if control_saving: - assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + assert ( + self.dp_rank == 0 and self.tp_rank == 0 + ), "The saving process should have both dp_rank and tp_rank as 0." index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) else: @@ -259,9 +273,11 @@ def save_sharded_model(self, save_config_file(model, checkpoint) rmtree(tmp_index_file_folder) if self.verbose: - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}.") + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): """ @@ -305,11 +321,9 @@ def _load(name: str): state_dict = load_shard_state_dict(Path(file_path), use_safetensors) missing_keys = [] - load_state_dict_into_model(model, - state_dict, - missing_keys=missing_keys, - strict=strict, - load_sub_module=True) + load_state_dict_into_model( + model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True + ) loaded_file.add(filename) # Load parameters. @@ -319,15 +333,17 @@ def _load(name: str): # Load buffers. non_persistent_buffers = set() for n, m in model.named_modules(): - non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set) + non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set) for name, buf in model.named_buffers(): if buf is not None and name not in non_persistent_buffers: _load(name) # Load extra states. extra_state_key = _EXTRA_STATE_KEY_SUFFIX - if getattr(model.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): _load(extra_state_key) # Update master params if mixed-precision training is enabled. @@ -352,12 +368,14 @@ def _load(name: str): if self.verbose: logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") - def save_sharded_optimizer(self, - optimizer: OptimizerWrapper, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024): + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): """ Save sharded optimizer checkpoint under the given checkpointing path. The following files will be created under the path: @@ -393,18 +411,21 @@ def save_sharded_optimizer(self, dp_group=self.dp_group, tp_group=self.tp_group, master_to_working_map=self.master_to_working_map, - size_per_shard=size_per_shard) + size_per_shard=size_per_shard, + ) states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) index_file = CheckpointIndexFile(checkpoint) - control_saving = (self.dp_rank == 0 and self.tp_rank == 0) + control_saving = self.dp_rank == 0 and self.tp_rank == 0 if self.pp_size == 1: # When pipeline is not used, save the optimizer shards as in general checkpointIO - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=control_saving) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + ) if control_saving: # Store param groups. @@ -415,9 +436,11 @@ def save_sharded_optimizer(self, index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) if self.verbose: - logging.info(f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) else: # When pipeline is used, each stage produces its own shard files and index files. @@ -433,15 +456,19 @@ def save_sharded_optimizer(self, save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=control_saving, - use_pp_format=True) + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + ) if control_saving: - assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + assert ( + self.dp_rank == 0 and self.tp_rank == 0 + ), "The saving process should have both dp_rank and tp_rank as 0." index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) else: @@ -451,7 +478,6 @@ def save_sharded_optimizer(self, # The global master rank integrates the index files and clean the folder. if self.pp_rank == 0: - final_index_file = CheckpointIndexFile(checkpoint) final_index_file.append_meta_data("total_size", 0) @@ -470,9 +496,11 @@ def save_sharded_optimizer(self, rmtree(tmp_index_file_folder) if self.verbose: - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}.") + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): """ @@ -484,20 +512,21 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_f prefix (str): Not used. """ - def _get_param_id_from_optimizer_param(param: torch.Tensor, - master_to_working_map: Optional[Dict[int, torch.Tensor]] = None): + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): if master_to_working_map is not None: working_param = master_to_working_map[id(param)] else: working_param = param - return optimizer.param_info['param2id'][id(working_param)] + return optimizer.param_info["param2id"][id(working_param)] # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. # When Zero is used, the mapped parameter objects should be fp32 master parameters. # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. id_map = {} for pg in optimizer.optim.param_groups: - for param in pg['params']: + for param in pg["params"]: param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) id_map[param_id] = param @@ -505,28 +534,30 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) ckpt_root_path = ckpt_index_file.root_path weight_map = ckpt_index_file.weight_map - weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int # Load param_groups param_group_path = ckpt_index_file.get_param_group_filename() if param_group_path is None: - raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \ - Lacking param group file under current directory.') + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) saved_groups = torch.load(param_group_path) updated_groups = [] for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): # obtain updated param group new_pg = copy.deepcopy(saved_pg) - new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change. + new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. updated_groups.append(new_pg) - optimizer.optim.__dict__.update({'param_groups': updated_groups}) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) # Load saved states to optimizer. # Keep a record of loaded files so that file will not be repeatedly loaded. loaded_file = set() for pg in optimizer.optim.param_groups: - for param in pg['params']: + for param in pg["params"]: if param is None: continue param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) @@ -550,12 +581,10 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, working_param = self.master_to_working_map[id(param)] else: working_param = param - original_shape = optimizer.param_info['param2shape'][id(working_param)] - sharded_state = self.shard_from_complete_optimizer_state(state, - current_shape=working_param.shape, - original_shape=original_shape, - device=device, - inplace=True) + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state( + state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True + ) optimizer.optim.state[param] = sharded_state sharded_optimizer_loading_epilogue(optimizer.optim) @@ -585,8 +614,11 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) - def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor], - master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]): + def link_master_and_working_param( + self, + working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor], + master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor], + ): """ Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings. This mapping can only be created when mixied precision is used. @@ -604,7 +636,8 @@ def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, t self.working_to_master_map[k] = v else: raise ValueError( - f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!" + ) self.master_to_working_map = dict() for k, v in master_to_working_map.items(): @@ -614,12 +647,19 @@ def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, t self.master_to_working_map[k] = v else: raise ValueError( - f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!" + ) @staticmethod - def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, - dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, - inplace: bool) -> OrderedDict: + def gather_from_sharded_optimizer_state( + state: OrderedDict, + param: torch.Tensor, + original_shape: torch.Size, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + use_zero: bool, + inplace: bool, + ) -> OrderedDict: """ With given parameter and its optimizer states, gather the complete optimizer state for saving. @@ -641,14 +681,13 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, state_ = state if inplace else copy.deepcopy(state) for k, v in state_.items(): - if isinstance(v, torch.Tensor) and k != 'step': - + if isinstance(v, torch.Tensor) and k != "step": # First gather Zero shards. if use_zero: v = v.cuda() gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] dist.all_gather(gather_tensor, v, group=dp_group) - v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) # Then gather TP shards. partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) @@ -661,9 +700,14 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, return state_ - def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size, - original_shape: torch.Size, device: torch.device, - inplace: bool) -> OrderedDict: + def shard_from_complete_optimizer_state( + self, + state: OrderedDict, + current_shape: torch.Size, + original_shape: torch.Size, + device: torch.device, + inplace: bool, + ) -> OrderedDict: """ With complete optimizer states of a specific parameter loaded from checkpoint, slice out the sharded optimizer states kept by current device. @@ -681,8 +725,7 @@ def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: state_ = state if inplace else copy.deepcopy(state) for k, v in state_.items(): - if isinstance(v, torch.Tensor) and k != 'step': - + if isinstance(v, torch.Tensor) and k != "step": # Shard state along tensor parallel group. partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) if partition_dim is not None: diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 388cf3fbe9bb..da12c146f2c3 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -6,7 +6,7 @@ from .utils import is_dtensor_checkpoint -__all__ = ['CheckpointIndexFile'] +__all__ = ["CheckpointIndexFile"] class CheckpointIndexFile: @@ -50,7 +50,7 @@ def load(self, json_path: str): json_path (str): path to the json file. """ # load the json file - with open(json_path, 'r') as f: + with open(json_path, "r") as f: index = json.load(f) # assign attributes if exists @@ -75,7 +75,7 @@ def export(self, json_path: str): index["weight_map"] = self.weight_map # export the index file - with open(json_path, 'w') as f: + with open(json_path, "w") as f: json.dump(index, f, indent=4) def append_weight_map(self, param_name: str, shard_file: str): diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 664ac63e45ac..c22b76dd46f7 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,5 +1,4 @@ # coding=utf-8 -import copy import os import re from collections import abc as container_abcs @@ -12,7 +11,7 @@ import torch.nn as nn from torch.optim import Optimizer -from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.interface import OptimizerWrapper from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, @@ -55,7 +54,6 @@ def is_safetensors_available() -> bool: bool: whether safetensors is available. """ try: - import safetensors return True except ImportError: return False @@ -71,7 +69,7 @@ def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool: Returns: bool: whether the checkpoint file is a dtensor checkpoint. """ - if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'): + if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"): return True else: return False @@ -87,7 +85,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: Returns: bool: whether the checkpoint file is a safetensor checkpoint. """ - if checkpoint_file_path.endswith('.safetensors'): + if checkpoint_file_path.endswith(".safetensors"): return True else: return False @@ -113,8 +111,9 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz partition_dim = dim break if partition_dim is not None: - assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \ - f"The parameter isn't evenly distributed among tensor parallel group: \ + assert ( + original_shape[partition_dim] == tp_size * current_shape[partition_dim] + ), f"The parameter isn't evenly distributed among tensor parallel group: \ shape before sharding {original_shape}, shape after sharding {current_shape}" return partition_dim @@ -124,24 +123,22 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz # Helper classes and functions for saving shard file # ====================================== def unwrap_optimizer(optimizer: OptimizerWrapper): - ''' + """ Unwrap a wrapped optimizer. This method should be used before saving/loading it to/from sharded checkpoints. - ''' + """ unwrapped_optim = optimizer.optim return unwrapped_optim class StateDictSharder: - def __init__(self, size_per_shard: int) -> None: self.max_shard_size = size_per_shard self.current_block = OrderedDict() self.current_block_size = 0 def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: - tensor_size = calculate_tensor_size(tensor) ret_block = None ret_block_size = 0 @@ -159,13 +156,11 @@ def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[Ordere return ret_block, ret_block_size def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]: - # A state might contain more than one tensors. # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' state_size = 0 isDTensor = False for state_tensor in state.values(): - # When state_tensor is not of Tensor class, # e.g., a SGD optimizer with momentum set to 0 can have None as state # The calculation of tensor size should be skipped to avoid error. @@ -217,14 +212,16 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to return param_ -def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], - checkpoint: str, - index_file: "CheckpointIndexFile", - base_filename: str, - is_master: bool, - use_safetensors: bool = False, - use_pp_format: bool = False) -> int: - ''' +def save_state_dict_shards( + sharded_state_dict: Iterator[Tuple[OrderedDict, int]], + checkpoint: str, + index_file: "CheckpointIndexFile", + base_filename: str, + is_master: bool, + use_safetensors: bool = False, + use_pp_format: bool = False, +) -> int: + """ Save sharded state dict only on master rank, this method can be used by both model and optimizer states. Args: sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size. @@ -237,7 +234,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] Returns: int: the total size of shards - ''' + """ total_size = 0 shard_filenames = [] @@ -288,7 +285,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> """ # Only split state_dict['state']; state_dict['param_group'] is not considered in this function. - states = state_dict['state'] + states = state_dict["state"] state_dict_sharder = StateDictSharder(max_shard_size) for param_id, state in states.items(): @@ -316,9 +313,11 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors """ if use_safetensors: assert is_safetensors_available(), "safetensors is not available." - assert checkpoint_file_path.endswith('.safetensors'), \ - "safetensors only supports .safetensors suffix for checkpoint file." + assert checkpoint_file_path.endswith( + ".safetensors" + ), "safetensors only supports .safetensors suffix for checkpoint file." from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) else: torch.save(state_dict, checkpoint_file_path) @@ -336,11 +335,13 @@ def save_param_groups(state_dict: dict, group_file_path: str) -> None: torch.save(param_groups, group_file_path) -def clean_folder(checkpoint_path: str, - weights_name: str, - shard_filenames: List[str], - is_master: bool = True, - use_pp_format: bool = False): +def clean_folder( + checkpoint_path: str, + weights_name: str, + shard_filenames: List[str], + is_master: bool = True, + use_pp_format: bool = False, +): """ Clean the unneeded files in checkpoint directory after shards of state_dict have been saved. @@ -362,8 +363,12 @@ def clean_folder(checkpoint_path: str, else: # When this checkpoint is created by pipeline parallel process, the pattern is a little different. reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}") - if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) - and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None): + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shard_filenames + and reg.fullmatch(filename_no_suffix) is not None + ): os.remove(full_filename) @@ -412,7 +417,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi size_per_shard (int): size per shard in MB. """ root_path = index_file.root_path - output_root_path = root_path.joinpath('dtensor') + output_root_path = root_path.joinpath("dtensor") # create directory output_root_path.mkdir(exist_ok=True) @@ -432,7 +437,7 @@ def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFi # update the weight map # * means all shards - ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) + ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors) index_file.append_weight_map(name, ckpt_file_name_in_weight_map) @@ -447,15 +452,14 @@ def get_checkpoint_file_suffix(use_safetensors: bool) -> str: str: checkpoint file suffix. """ if use_safetensors: - return '.safetensors' + return ".safetensors" else: - return '.bin' + return ".bin" -def generate_checkpoint_shard_file_name(index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None) -> str: +def generate_checkpoint_shard_file_name( + index: int, total_number: int, use_safetensors: bool, prefix: str = None +) -> str: """ Generate checkpoint shard file name. @@ -489,7 +493,7 @@ def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: boo str: dtensor file name. """ suffix = get_checkpoint_file_suffix(use_safetensors) - return f'{param_name}.{index}.{suffix}' + return f"{param_name}.{index}.{suffix}" # ======================================== @@ -506,21 +510,21 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): if use_safetensors: from safetensors.torch import load_file as safe_load_file from safetensors.torch import safe_open + with safe_open(checkpoint_file, framework="pt") as f: metadata = f.metadata() if metadata["format"] != "pt": raise NotImplementedError( - f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.") + f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." + ) return safe_load_file(checkpoint_file) else: - return torch.load(checkpoint_file, map_location=torch.device('cpu')) + return torch.load(checkpoint_file, map_location=torch.device("cpu")) -def load_state_dict_into_model(model: nn.Module, - state_dict: torch.Tensor, - missing_keys: List, - strict: bool = False, - load_sub_module: bool = True): +def load_state_dict_into_model( + model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True +): r"""Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. @@ -536,7 +540,7 @@ def load_state_dict_into_model(model: nn.Module, error_msgs: List[str] = [] # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) + metadata = getattr(state_dict, "_metadata", None) state_dict = OrderedDict(state_dict) if metadata is not None: state_dict._metadata = metadata @@ -560,10 +564,12 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True) if strict: if len(unexpected_keys) > 0: - error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join( - '"{}"'.format(k) for k in unexpected_keys)) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) + error_msgs = "Unexpected key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in unexpected_keys) + ) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) + ) def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict: @@ -573,9 +579,9 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str # Load list of param_groups from given file path. # The params in saved_groups are in the form of integer indices. - saved_groups = torch.load(param_group_path, map_location=torch.device('cpu')) + saved_groups = torch.load(param_group_path, map_location=torch.device("cpu")) if not isinstance(saved_groups, List): - raise ValueError(f'The param_groups saved at {param_group_path} is not of List type') + raise ValueError(f"The param_groups saved at {param_group_path} is not of List type") # The params in param_groups are in the form of pytorch tensors. # For more details, please view source code of Optimizer class in pytorch. @@ -584,26 +590,30 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str # Check the compatibility of saved_groups and param_groups. if len(param_groups) != len(saved_groups): raise ValueError("loaded state dict has a different number of original parameter groups") - param_lens = (len(g['params']) for g in param_groups) - saved_lens = (len(g['params']) for g in saved_groups) + param_lens = (len(g["params"]) for g in param_groups) + saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError("loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group") + raise ValueError( + "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" + ) # Creating mapping from id to parameters. id_map = { - old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups - )), chain.from_iterable((g['params'] for g in param_groups))) + old_id: p + for old_id, p in zip( + chain.from_iterable((g["params"] for g in saved_groups)), + chain.from_iterable((g["params"] for g in param_groups)), + ) } # Update parameter groups, setting their 'params' value. def update_group(group, new_group): - new_group['params'] = group['params'] + new_group["params"] = group["params"] return new_group updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)] - optimizer.__dict__.update({'param_groups': updated_groups}) + optimizer.__dict__.update({"param_groups": updated_groups}) return id_map @@ -628,7 +638,7 @@ def cast(param, value, key=None): # Floating-point types are a bit special here. They are the only ones # that are assumed to always match the type of params. # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 - if (key != "step"): + if key != "step": if param.is_floating_point(): value = value.to(param.dtype) value = value.to(param.device) @@ -662,8 +672,8 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer): """ # Do the cleaning up as in src code of Pytorch. - optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. - optimizer.defaults.setdefault('differentiable', False) + optimizer._hook_for_profile() # To support multiprocessing pickle/unpickle. + optimizer.defaults.setdefault("differentiable", False) def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: @@ -686,20 +696,20 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: return False, None elif checkpoint_path.is_dir(): # check if there is only one a file ending with .index.json in this directory - index_files = list(checkpoint_path.glob('*.index.*json')) + index_files = list(checkpoint_path.glob("*.index.*json")) # if we found a .index.json file, make sure there is only one if len(index_files) > 0: - assert len( - index_files - ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}' + assert ( + len(index_files) == 1 + ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}" if len(index_files) == 1: return True, index_files[0] else: return False, None else: - raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.') + raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.") def load_state_dict(checkpoint_file_path: Path): @@ -713,14 +723,17 @@ def load_state_dict(checkpoint_file_path: Path): dict: state dict. """ - assert not is_dtensor_checkpoint(checkpoint_file_path), \ - f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.' + assert not is_dtensor_checkpoint( + checkpoint_file_path + ), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline." if is_safetensor_checkpoint(checkpoint_file_path): - assert is_safetensors_available(), \ - f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.' + assert ( + is_safetensors_available() + ), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors." # load with safetensors from safetensors import safe_open + state_dict = {} with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f: for k in f.keys(): @@ -729,7 +742,7 @@ def load_state_dict(checkpoint_file_path: Path): else: # load with torch - return torch.load(checkpoint_file_path, map_location=torch.device('cpu')) + return torch.load(checkpoint_file_path, map_location=torch.device("cpu")) def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str: diff --git a/colossalai/cli/__init__.py b/colossalai/cli/__init__.py index 658e35e4c72e..c7cb19c19308 100644 --- a/colossalai/cli/__init__.py +++ b/colossalai/cli/__init__.py @@ -1,3 +1,3 @@ from .cli import cli -__all__ = ['cli'] +__all__ = ["cli"] diff --git a/colossalai/cli/check/__init__.py b/colossalai/cli/check/__init__.py index a86b32bb6a18..7c26ab6ade6c 100644 --- a/colossalai/cli/check/__init__.py +++ b/colossalai/cli/check/__init__.py @@ -1,11 +1,12 @@ import click + from .check_installation import check_installation -__all__ = ['check'] +__all__ = ["check"] @click.command(help="Check if Colossal-AI is correct based on the given option") -@click.option('-i', '--installation', is_flag=True, help="Check if Colossal-AI is built correctly") +@click.option("-i", "--installation", is_flag=True, help="Check if Colossal-AI is built correctly") def check(installation): if installation: check_installation() diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py index 4a481f3bd122..772c513ffa06 100644 --- a/colossalai/cli/check/check_installation.py +++ b/colossalai/cli/check/check_installation.py @@ -9,7 +9,7 @@ def to_click_output(val): # installation check output to understandable symbols for readability - VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'} + VAL_TO_SYMBOL = {True: "\u2713", False: "x", None: "N/A"} if val in VAL_TO_SYMBOL: return VAL_TO_SYMBOL[val] @@ -55,8 +55,8 @@ def check_installation(): else: torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required]) - click.echo(f'#### Installation Report ####') - click.echo(f'\n------------ Environment ------------') + click.echo(f"#### Installation Report ####") + click.echo(f"\n------------ Environment ------------") click.echo(f"Colossal-AI version: {to_click_output(colossalai_version)}") click.echo(f"PyTorch version: {to_click_output(torch_version)}") click.echo(f"System CUDA version: {to_click_output(cuda_version)}") @@ -69,7 +69,7 @@ def check_installation(): f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version." ) - click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------') + click.echo(f"\n------------ CUDA Extensions AOT Compilation ------------") click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}") click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}") click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}") @@ -81,7 +81,7 @@ def check_installation(): click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime") click.echo(f"\n------------ Compatibility ------------") - click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}') + click.echo(f"PyTorch version match: {to_click_output(torch_compatibility)}") click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}") click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}") click.echo(f"") @@ -106,12 +106,12 @@ def _is_compatible(versions): return False # split version into [major, minor, patch] - versions = [version.split('.') for version in versions] + versions = [version.split(".") for version in versions] for version in versions: if len(version) == 2: # x means unknown - version.append('x') + version.append("x") for idx, version_values in enumerate(zip(*versions)): equal = len(set(version_values)) == 1 @@ -137,11 +137,11 @@ def _parse_colossalai_version(): # 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions) # 2. X.X.X (when colossalai is not installed with CUDA extensions) # where X represents an integer. - colossalai_version = colossalai.__version__.split('+')[0] + colossalai_version = colossalai.__version__.split("+")[0] try: - torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0] - cuda_version_for_aot_build = colossalai.__version__.split('cu')[1] + torch_version_for_aot_build = colossalai.__version__.split("torch")[1].split("cu")[0] + cuda_version_for_aot_build = colossalai.__version__.split("cu")[1] except: torch_version_for_aot_build = None cuda_version_for_aot_build = None @@ -156,7 +156,6 @@ def _check_aot_built_cuda_extension_installed(): JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime. """ try: - import colossalai._C.fused_optim found_aot_cuda_ext = True except ImportError: found_aot_cuda_ext = False @@ -175,14 +174,14 @@ def _check_torch_version(): # torch version can be of two formats # - 1.13.1+cu113 # - 1.13.1.devxxx - torch_version = torch.__version__.split('+')[0] - torch_version = '.'.join(torch_version.split('.')[:3]) + torch_version = torch.__version__.split("+")[0] + torch_version = ".".join(torch_version.split(".")[:3]) # get cuda version in pytorch build try: torch_cuda_major = torch.version.cuda.split(".")[0] torch_cuda_minor = torch.version.cuda.split(".")[1] - torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}' + torch_cuda_version = f"{torch_cuda_major}.{torch_cuda_minor}" except: torch_cuda_version = None @@ -208,7 +207,7 @@ def _check_cuda_version(): release = output[release_idx].split(".") bare_metal_major = release[0] bare_metal_minor = release[1][0] - cuda_version = f'{bare_metal_major}.{bare_metal_minor}' + cuda_version = f"{bare_metal_major}.{bare_metal_minor}" except: cuda_version = None return cuda_version diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py index 0dea7c504957..0d94fe59f8ae 100644 --- a/colossalai/cli/cli.py +++ b/colossalai/cli/cli.py @@ -4,8 +4,7 @@ from .launcher import run -class Arguments(): - +class Arguments: def __init__(self, arg_dict): for k, v in arg_dict.items(): self.__dict__[k] = v @@ -19,5 +18,5 @@ def cli(): cli.add_command(run) cli.add_command(check) -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py index 808e4e84574f..0f9ead6495db 100644 --- a/colossalai/cli/launcher/__init__.py +++ b/colossalai/cli/launcher/__init__.py @@ -5,56 +5,81 @@ from .run import launch_multi_processes -@click.command(help="Launch distributed training on a single node or multiple nodes", - context_settings=dict(ignore_unknown_options=True)) -@click.option("-H", - "-host", - "--host", - type=str, - default=None, - help="the list of hostnames to launch in the format ,") +@click.command( + help="Launch distributed training on a single node or multiple nodes", + context_settings=dict(ignore_unknown_options=True), +) +@click.option( + "-H", + "-host", + "--host", + type=str, + default=None, + help="the list of hostnames to launch in the format ,", +) @click.option( "--hostfile", type=str, default=None, - help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname") -@click.option("--include", - type=str, - default=None, - help="Specify computing devices to use during execution. String format is ,," - " only effective when used with --hostfile.") + help="Hostfile path that defines the device pool available to the job, each line in the file is a hostname", +) +@click.option( + "--include", + type=str, + default=None, + help="Specify computing devices to use during execution. String format is ,," + " only effective when used with --hostfile.", +) @click.option( "--exclude", type=str, default=None, - help= - "Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include," - " only effective when used with --hostfile.") -@click.option("--num_nodes", - type=int, - default=-1, - help="Total number of worker nodes to use, only effective when used with --hostfile.") + help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include," + " only effective when used with --hostfile.", +) +@click.option( + "--num_nodes", + type=int, + default=-1, + help="Total number of worker nodes to use, only effective when used with --hostfile.", +) @click.option("--nproc_per_node", type=int, default=None, help="Number of GPUs to use on each node.") -@click.option("--master_port", - type=int, - default=29500, - help="(optional) Port used by PyTorch distributed for communication during distributed training.") -@click.option("--master_addr", - type=str, - default="127.0.0.1", - help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.") +@click.option( + "--master_port", + type=int, + default=29500, + help="(optional) Port used by PyTorch distributed for communication during distributed training.", +) +@click.option( + "--master_addr", + type=str, + default="127.0.0.1", + help="(optional) IP address of node 0, will be inferred via 'hostname -I' if not specified.", +) @click.option( "--extra_launch_args", type=str, default=None, - help= - "Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. " - "This will be converted to --arg1=1 --arg2=2 during execution") + help="Set additional torch distributed launcher arguments such as --standalone. The format is --extra_launch_args arg1=1,arg2=2. " + "This will be converted to --arg1=1 --arg2=2 during execution", +) @click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection") @click.argument("user_script", type=str) -@click.argument('user_args', nargs=-1) -def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: str, exclude: str, master_addr: str, - master_port: int, extra_launch_args: str, ssh_port: int, user_script: str, user_args: str) -> None: +@click.argument("user_args", nargs=-1) +def run( + host: str, + hostfile: str, + num_nodes: int, + nproc_per_node: int, + include: str, + exclude: str, + master_addr: str, + master_port: int, + extra_launch_args: str, + ssh_port: int, + user_script: str, + user_args: str, +) -> None: """ To launch multiple processes on a single node or multiple nodes via command line. @@ -77,8 +102,8 @@ def run(host: str, hostfile: str, num_nodes: int, nproc_per_node: int, include: # run with hostfile excluding the hosts selected colossalai run --hostfile --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py """ - if not user_script.endswith('.py'): - click.echo(f'Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help') + if not user_script.endswith(".py"): + click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help") exit() args_dict = locals() diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py index 2a6a111e4d72..684f64f59d28 100644 --- a/colossalai/cli/launcher/hostinfo.py +++ b/colossalai/cli/launcher/hostinfo.py @@ -1,5 +1,4 @@ import socket -from typing import List class HostInfo: @@ -34,7 +33,7 @@ def is_host_localhost(hostname: str, port: str = None) -> None: """ if port is None: - port = 22 # no port specified, lets just use the ssh port + port = 22 # no port specified, lets just use the ssh port # socket.getfqdn("127.0.0.1") does not return localhost # on some users' machines @@ -50,7 +49,7 @@ def is_host_localhost(hostname: str, port: str = None) -> None: return localaddrs == targetaddrs def __str__(self): - return f'hostname: {self.hostname}, port: {self.port}' + return f"hostname: {self.hostname}, port: {self.port}" def __repr__(self): return self.__str__() diff --git a/colossalai/cli/launcher/multinode_runner.py b/colossalai/cli/launcher/multinode_runner.py index 85b241e96292..99c4db406844 100644 --- a/colossalai/cli/launcher/multinode_runner.py +++ b/colossalai/cli/launcher/multinode_runner.py @@ -7,8 +7,13 @@ from .hostinfo import HostInfo, HostInfoList -def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection, - send_conn: mp_connection.Connection, env: dict) -> None: +def run_on_host( + hostinfo: HostInfo, + workdir: str, + recv_conn: mp_connection.Connection, + send_conn: mp_connection.Connection, + env: dict, +) -> None: """ Use fabric connection to execute command on local or remote hosts. @@ -22,14 +27,14 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne fab_conn = fabric.Connection(hostinfo.hostname, port=hostinfo.port) finish = False - env_msg = ' '.join([f'{k}=\"{v}\"' for k, v in env.items()]) + env_msg = " ".join([f'{k}="{v}"' for k, v in env.items()]) # keep listening until exit while not finish: # receive cmd cmds = recv_conn.recv() - if cmds == 'exit': + if cmds == "exit": # exit from the loop finish = True break @@ -46,12 +51,12 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne else: # execute on the remote machine fab_conn.run(cmds, hide=False) - send_conn.send('success') + send_conn.send("success") except Exception as e: click.echo( f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}" ) - send_conn.send('failure') + send_conn.send("failure") # shutdown send_conn.send("finish") @@ -96,8 +101,7 @@ def send(self, hostinfo: HostInfo, cmd: str) -> None: cmd (str): the command to execute """ - assert hostinfo.hostname in self.master_send_conns, \ - f'{hostinfo} is not found in the current connections' + assert hostinfo.hostname in self.master_send_conns, f"{hostinfo} is not found in the current connections" conn = self.master_send_conns[hostinfo.hostname] conn.send(cmd) @@ -107,7 +111,7 @@ def stop_all(self) -> None: """ for hostname, conn in self.master_send_conns.items(): - conn.send('exit') + conn.send("exit") def recv_from_all(self) -> dict: """ diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index d2d02811ac9d..7ca8ee90386c 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -12,7 +12,7 @@ from .multinode_runner import MultiNodeRunner # Constants that define our syntax -NODE_SEP = ',' +NODE_SEP = "," def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: @@ -34,12 +34,12 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: click.echo(f"Error: Unable to find the hostfile, no such file: {hostfile_path}") exit() - with open(hostfile_path, 'r') as fd: + with open(hostfile_path, "r") as fd: device_pool = HostInfoList() for line in fd.readlines(): line = line.strip() - if line == '': + if line == "": # skip empty lines continue @@ -56,7 +56,7 @@ def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList: def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str=None) -> HostInfoList: - '''Parse an inclusion or exclusion string and filter a hostfile dictionary. + """Parse an inclusion or exclusion string and filter a hostfile dictionary. Examples: include_str="worker-0,worker-1" will execute jobs only on worker-0 and worker-1. @@ -69,7 +69,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str Returns: filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion - ''' + """ # Ensure include/exclude are mutually exclusive if include_str and exclude_str: @@ -136,16 +136,16 @@ def _arg_dict_to_list(arg_dict): for k, v in arg_dict.items(): if v: - ret.append(f'--{k}={v}') + ret.append(f"--{k}={v}") else: - ret.append(f'--{k}') + ret.append(f"--{k}") return ret if extra_launch_args: extra_launch_args_dict = dict() - for arg in extra_launch_args.split(','): - if '=' in arg: - k, v = arg.split('=') + for arg in extra_launch_args.split(","): + if "=" in arg: + k, v = arg.split("=") extra_launch_args_dict[k] = v else: extra_launch_args_dict[arg] = None @@ -158,9 +158,14 @@ def _arg_dict_to_list(arg_dict): if torch_version.minor < 9: cmd = [ - sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", - f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}", - f"--node_rank={node_rank}" + sys.executable, + "-m", + "torch.distributed.launch", + f"--nproc_per_node={nproc_per_node}", + f"--master_addr={master_addr}", + f"--master_port={master_port}", + f"--nnodes={num_nodes}", + f"--node_rank={node_rank}", ] else: # extra launch args for torch distributed launcher with torch >= 1.9 @@ -174,17 +179,24 @@ def _arg_dict_to_list(arg_dict): if torch_version.minor < 10: cmd = [ - sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}", - f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + sys.executable, + "-m", + "torch.distributed.run", + f"--nproc_per_node={nproc_per_node}", + f"--nnodes={num_nodes}", + f"--node_rank={node_rank}", ] else: cmd = [ - "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" + "torchrun", + f"--nproc_per_node={nproc_per_node}", + f"--nnodes={num_nodes}", + f"--node_rank={node_rank}", ] cmd += _arg_dict_to_list(default_torchrun_rdzv_args) cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args - cmd = ' '.join(cmd) + cmd = " ".join(cmd) return cmd @@ -248,18 +260,18 @@ def launch_multi_processes(args: Config) -> None: # run on local node if not hosts or hostfile is given # add local node to host info list active_device_pool = HostInfoList() - localhost_info = HostInfo(hostname='127.0.0.1', port=args.ssh_port) + localhost_info = HostInfo(hostname="127.0.0.1", port=args.ssh_port) active_device_pool.append(localhost_info) # launch distributed processes runner = MultiNodeRunner() - curr_path = os.path.abspath('.') + curr_path = os.path.abspath(".") # collect current path env env = dict() for k, v in os.environ.items(): # do not support multi-line env var - if v and '\n' not in v: + if v and "\n" not in v: env[k] = v # establish remote connection @@ -271,14 +283,16 @@ def launch_multi_processes(args: Config) -> None: # execute distributed launching command for node_id, hostinfo in enumerate(active_device_pool): - cmd = get_launch_command(master_addr=args.master_addr, - master_port=args.master_port, - nproc_per_node=args.nproc_per_node, - user_script=args.user_script, - user_args=args.user_args, - node_rank=node_id, - num_nodes=len(active_device_pool), - extra_launch_args=args.extra_launch_args) + cmd = get_launch_command( + master_addr=args.master_addr, + master_port=args.master_port, + nproc_per_node=args.nproc_per_node, + user_script=args.user_script, + user_args=args.user_args, + node_rank=node_id, + num_nodes=len(active_device_pool), + extra_launch_args=args.extra_launch_args, + ) runner.send(hostinfo=hostinfo, cmd=cmd) # start training diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py index 44f571ca2501..b8176feb647b 100644 --- a/colossalai/cluster/__init__.py +++ b/colossalai/cluster/__init__.py @@ -3,4 +3,4 @@ from .process_group_manager import ProcessGroupManager from .process_group_mesh import ProcessGroupMesh -__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager', 'ProcessGroupMesh'] +__all__ = ["DistCoordinator", "ProcessGroupManager", "DeviceMeshManager", "ProcessGroupMesh"] diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py index 8754baa19792..e35aca5f4d7e 100644 --- a/colossalai/cluster/device_mesh_manager.py +++ b/colossalai/cluster/device_mesh_manager.py @@ -10,13 +10,14 @@ @dataclass class DeviceMeshInfo: - ''' + """ This class is used to store the information used to initialize the device mesh. Args: physical_ids (List[int]): The physical ids of the current booster. For example, if we have the last 4 GPUs on a 8-devices cluster, then the physical ids should be [4, 5, 6, 7]. mesh_shapes (List[Union[torch.Size, List[int], Tuple[int]]]): The shape of the mesh. For example, if we have 4 GPUs and we want to use 2D mesh with mesh shape [2, 2], then the mesh shape should be [2, 2]. - ''' + """ + physical_ids: List[int] mesh_shape: Union[torch.Size, List[int], Tuple[int]] = None @@ -24,16 +25,18 @@ def __post_init__(self): if self.mesh_shape is not None: world_size = len(self.physical_ids) mesh_shape_numel = torch.Size(self.mesh_shape).numel() - assert world_size == mesh_shape_numel, f'the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}' + assert ( + world_size == mesh_shape_numel + ), f"the numel of mesh_shape should be equal to world size, but got {world_size} != {mesh_shape_numel}" def initialize_device_mesh(device_mesh_info: DeviceMeshInfo): - ''' + """ This method is used to initialize the device mesh. Args: device_mesh_info (DeviceMeshInfo): The information used to initialize device mesh. - ''' + """ # parse the device mesh info physical_devices = device_mesh_info.physical_ids physical_mesh = torch.tensor(physical_devices) @@ -67,13 +70,13 @@ def create_device_mesh(self, name, device_mesh_info: DeviceMeshInfo) -> DeviceMe Args: name (str): name of the device mesh device_mesh_info (DeviceMeshInfo): the information used to initialize the device mesh - """ + """ if name not in self.device_mesh_store: device_mesh = initialize_device_mesh(device_mesh_info) self.device_mesh_store[name] = device_mesh return device_mesh else: - raise ValueError(f'Device mesh {name} already exists.') + raise ValueError(f"Device mesh {name} already exists.") def get(self, name: str) -> DeviceMesh: """ @@ -88,7 +91,7 @@ def get(self, name: str) -> DeviceMesh: if name in self.device_mesh_store: return self.device_mesh_store[name] else: - raise ValueError(f'Device mesh {name} does not exist.') + raise ValueError(f"Device mesh {name} does not exist.") def destroy(self, name: str) -> None: """ @@ -103,7 +106,7 @@ def destroy(self, name: str) -> None: dist.destroy_process_group(pg) del self.device_mesh_store[name] else: - raise ValueError(f'Device mesh {name} does not exist.') + raise ValueError(f"Device mesh {name} does not exist.") def destroy_all(self): """ diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py index 3ee364ec3364..5b66e88717ba 100644 --- a/colossalai/cluster/dist_coordinator.py +++ b/colossalai/cluster/dist_coordinator.py @@ -36,12 +36,13 @@ class in the whole program. """ def __init__(self): - assert dist.is_initialized( - ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.' + assert ( + dist.is_initialized() + ), "Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first." self._rank = dist.get_rank() self._world_size = dist.get_world_size() # this is often passed by launchers such as torchrun - self._local_rank = os.environ.get('LOCAL_RANK', -1) + self._local_rank = os.environ.get("LOCAL_RANK", -1) @property def rank(self) -> int: @@ -59,7 +60,9 @@ def _assert_local_rank_set(self): """ Assert that the local rank is set. This is often passed by launchers such as torchrun. """ - assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.' + assert ( + self.local_rank >= 0 + ), "The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process." def is_master(self, process_group: ProcessGroup = None) -> bool: """ @@ -183,7 +186,6 @@ def on_master_only(self, process_group: ProcessGroup = None): # define an inner function def decorator(func): - @functools.wraps(func) def wrapper(*args, **kwargs): if is_master: diff --git a/colossalai/cluster/process_group_manager.py b/colossalai/cluster/process_group_manager.py index e52661846f3e..68106b503126 100644 --- a/colossalai/cluster/process_group_manager.py +++ b/colossalai/cluster/process_group_manager.py @@ -19,7 +19,7 @@ class ProcessGroupManager: def __init__(self): self.pg_store = dict() - def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup: + def create_process_group(self, name: str, ranks: List[int], backend: str = "nccl") -> ProcessGroup: """ Get a process group by name. If the process group does not exist, it will be created. @@ -36,7 +36,7 @@ def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl self.pg_store[name] = pg return pg else: - raise ValueError(f'Process group {name} already exists.') + raise ValueError(f"Process group {name} already exists.") def get(self, name: str) -> ProcessGroup: """ @@ -51,7 +51,7 @@ def get(self, name: str) -> ProcessGroup: if name in self.pg_store: return self.pg_store[name] else: - raise ValueError(f'Process group {name} does not exist.') + raise ValueError(f"Process group {name} does not exist.") def destroy(self, name: str) -> None: """ @@ -64,7 +64,7 @@ def destroy(self, name: str) -> None: dist.destroy_process_group(self.pg_store[name]) del self.pg_store[name] else: - raise ValueError(f'Process group {name} does not exist.') + raise ValueError(f"Process group {name} does not exist.") def destroy_all(self) -> None: """ diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 623160003767..3885bc962561 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -94,7 +94,7 @@ def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]: return np.unravel_index(rank, shape) @staticmethod - def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int: + def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = "raise") -> int: """Convert a coordinate to a rank. mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html. with wrap, index out of range would be wrapped around. @@ -141,8 +141,9 @@ def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: return list(self._group_to_ranks[group]) @staticmethod - def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int, - indices_at_axis: List[int]) -> List[Tuple[int, ...]]: + def get_coords_along_axis( + base_coord: Tuple[int, ...], axis: int, indices_at_axis: List[int] + ) -> List[Tuple[int, ...]]: """Get coordinates along the given axis. Args: @@ -155,13 +156,12 @@ def get_coords_along_axis(base_coord: Tuple[int, ...], axis: int, """ coords_in_group = [] for idx in indices_at_axis: - coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1:]) + coords_in_group.append(base_coord[:axis] + (idx,) + base_coord[axis + 1 :]) return coords_in_group - def create_group_along_axis(self, - axis: int, - indices_at_axis: Optional[List[int]] = None, - backend: Optional[str] = None) -> ProcessGroup: + def create_group_along_axis( + self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + ) -> ProcessGroup: """Create all process groups along the given axis, and return the one which the current process belongs to. Args: @@ -186,10 +186,9 @@ def create_group_along_axis(self, target_group = group return target_group - def get_group_along_axis(self, - axis: int, - indices_at_axis: Optional[List[int]] = None, - backend: Optional[str] = None) -> ProcessGroup: + def get_group_along_axis( + self, axis: int, indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + ) -> ProcessGroup: """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. Args: diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py index eb6d5d05a008..ab57301bb910 100644 --- a/colossalai/context/__init__.py +++ b/colossalai/context/__init__.py @@ -3,6 +3,6 @@ # from .moe_context import MOE_CONTEXT __all__ = [ - 'Config', - 'ConfigException', + "Config", + "ConfigException", ] diff --git a/colossalai/context/config.py b/colossalai/context/config.py index 8903707708df..05a2e4bf044a 100644 --- a/colossalai/context/config.py +++ b/colossalai/context/config.py @@ -5,6 +5,7 @@ import sys from importlib.machinery import SourceFileLoader from pathlib import Path + from colossalai.logging import get_dist_logger @@ -41,7 +42,7 @@ def _add_item(self, key, value): self.__setattr__(key, value) def update(self, config): - assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.' + assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects." for k, v in config.items(): self._add_item(k, v) return self @@ -66,11 +67,11 @@ def from_file(filename: str): elif isinstance(filename, Path): filepath = filename.absolute() - assert filepath.exists(), f'{filename} is not found, please check your configuration path' + assert filepath.exists(), f"{filename} is not found, please check your configuration path" # check extension extension = filepath.suffix - assert extension == '.py', 'only .py files are supported' + assert extension == ".py", "only .py files are supported" # import the config as module remove_path = False @@ -86,13 +87,13 @@ def from_file(filename: str): config = Config() for k, v in module.__dict__.items(): - if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v): + if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v): continue else: config._add_item(k, v) logger = get_dist_logger() - logger.debug('variables which starts with __, is a module or class declaration are omitted in config file') + logger.debug("variables which starts with __, is a module or class declaration are omitted in config file") # remove module del sys.modules[module_name] diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index b6e3b52017b2..066dfc7222e1 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -9,14 +9,13 @@ def _check_sanity(): from colossalai.legacy.core import global_context as gpc + if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1: - raise NotImplementedError("Moe is not compatible with tensor or " - "pipeline parallel at present.") + raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.") class MoeParallelInfo: - """Moe parallelism information, storing parallel sizes and groups. - """ + """Moe parallelism information, storing parallel sizes and groups.""" def __init__(self, ep_size: int, dp_size: int): _check_sanity() @@ -61,9 +60,11 @@ def setup(self, seed: int, use_kernel_optim: bool = True): self.world_size = dist.get_world_size() from colossalai.legacy.core import global_context as gpc - self.max_ep_size = gpc.config.get('max_ep_size', self.world_size) - assert self.world_size % self.max_ep_size == 0, \ - "Maximum expert parallel size must be a factor of the number of GPUs" + + self.max_ep_size = gpc.config.get("max_ep_size", self.world_size) + assert ( + self.world_size % self.max_ep_size == 0 + ), "Maximum expert parallel size must be a factor of the number of GPUs" self.min_dp_size = self.world_size // self.max_ep_size # Enabling kernel optimization may raise error in some cases @@ -71,6 +72,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True): self.use_kernel_optim = use_kernel_optim from .random import moe_set_seed + moe_set_seed(seed) self.has_setup = True @@ -88,11 +90,13 @@ def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]: number of local experts, the MoeParallelInfo of the current ep_size """ - gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater - lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less + gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater + lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less - assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ - " is not a multiple of ep size or vice versa." + assert gt_flag or lt_flag, ( + "Automatic experts placement dose not not support expert number" + " is not a multiple of ep size or vice versa." + ) # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size, # there are multiple experts in each GPU and each GPU has different experts diff --git a/colossalai/context/singleton_meta.py b/colossalai/context/singleton_meta.py index 8ca335119d52..3088b0dffaac 100644 --- a/colossalai/context/singleton_meta.py +++ b/colossalai/context/singleton_meta.py @@ -16,6 +16,7 @@ def __call__(cls, *args, **kwargs): instance = super().__call__(*args, **kwargs) cls._instances[cls] = instance else: - assert len(args) == 0 and len( - kwargs) == 0, f'{cls.__name__} is a singleton class and a instance has been created.' + assert ( + len(args) == 0 and len(kwargs) == 0 + ), f"{cls.__name__} is a singleton class and a instance has been created." return cls._instances[cls] diff --git a/colossalai/device/__init__.py b/colossalai/device/__init__.py index 689189998c3f..34a7d2526fda 100644 --- a/colossalai/device/__init__.py +++ b/colossalai/device/__init__.py @@ -1,4 +1,4 @@ from .alpha_beta_profiler import AlphaBetaProfiler from .calc_pipeline_strategy import alpa_dp -__all__ = ['AlphaBetaProfiler', 'alpa_dp'] +__all__ = ["AlphaBetaProfiler", "alpa_dp"] diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py index f4e6cfffbcdf..88520b2a14d0 100644 --- a/colossalai/device/alpha_beta_profiler.py +++ b/colossalai/device/alpha_beta_profiler.py @@ -13,7 +13,7 @@ class AlphaBetaProfiler: - ''' + """ Profile alpha and beta value for a given device list. Usage: @@ -27,17 +27,19 @@ class AlphaBetaProfiler: (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12), (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11), (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)} - ''' - - def __init__(self, - physical_devices: List[int], - alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None, - ctype: str = 'a', - warmup: int = 5, - repeat: int = 25, - latency_iters: int = 5, - homogeneous_tolerance: float = 0.1): - ''' + """ + + def __init__( + self, + physical_devices: List[int], + alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None, + ctype: str = "a", + warmup: int = 5, + repeat: int = 25, + latency_iters: int = 5, + homogeneous_tolerance: float = 0.1, + ): + """ Args: physical_devices: A list of device id, each element inside it is the global rank of that device. alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs. @@ -45,7 +47,7 @@ def __init__(self, warmup: Number of warmup iterations. repeat: Number of iterations to measure. latency_iters: Number of iterations to measure latency. - ''' + """ self.physical_devices = physical_devices self.ctype = ctype self.world_size = len(physical_devices) @@ -123,7 +125,7 @@ def _profile(self, process_group, pg_handler, nbytes): return (None, None) def profile_latency(self, process_group, pg_handler): - ''' + """ This function is used to profile the latency of the given process group with a series of bytes. Args: @@ -132,7 +134,7 @@ def profile_latency(self, process_group, pg_handler): Returns: latency: None if the latency is not measured, otherwise the median of the latency_list. - ''' + """ latency_list = [] for i in range(self.latency_iters): nbytes = int(BYTE << i) @@ -148,26 +150,26 @@ def profile_latency(self, process_group, pg_handler): return latency def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)): - ''' + """ This function is used to profile the bandwidth of the given process group. Args: process_group: A tuple of global rank of the process group. pg_handler: The handler of the process group. - ''' + """ (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes) return bandwidth def profile_ab(self): - ''' + """ This method is used to profiling the alpha and beta value for a given device list. Returns: alpha_beta_dict: A dict which maps process group to its alpha and beta value. - ''' + """ alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {} rank = dist.get_rank() - global_pg_handler = dist.new_group(self.physical_devices) + dist.new_group(self.physical_devices) def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup): assert rank in process_group @@ -208,7 +210,7 @@ def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup): return alpha_beta_dict def search_best_logical_mesh(self): - ''' + """ This method is used to search the best logical mesh for the given device list. The best logical mesh is searched in following steps: @@ -232,19 +234,19 @@ def search_best_logical_mesh(self): >>> best_logical_mesh = profiler.search_best_logical_mesh() >>> print(best_logical_mesh) [[0, 1], [2, 3]] - ''' + """ def _power_of_two(integer): return integer & (integer - 1) == 0 def _detect_homogeneous_device(alpha_beta_dict): - ''' + """ This function is used to detect whether the devices in the alpha_beta_dict are homogeneous. Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)] * base_beta. - ''' + """ homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {} for process_group, (_, beta) in alpha_beta_dict.items(): if homogeneous_device_dict is None: @@ -254,7 +256,8 @@ def _detect_homogeneous_device(alpha_beta_dict): match_beta = None for beta_value in homogeneous_device_dict.keys(): if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * ( - 1 - self.homogeneous_tolerance): + 1 - self.homogeneous_tolerance + ): match_beta = beta_value break @@ -267,9 +270,9 @@ def _detect_homogeneous_device(alpha_beta_dict): return homogeneous_device_dict def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]): - ''' + """ This function is used to check whether the homogeneous_group contains all physical devices. - ''' + """ flatten_mesh = [] for process_group in homogeneous_group: flatten_mesh.extend(process_group) @@ -277,9 +280,9 @@ def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]): return len(non_duplicated_flatten_mesh) == len(self.physical_devices) def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): - ''' + """ This function is used to construct the largest ring in the homogeneous_group for each rank. - ''' + """ # Construct the ring ring = [] ranks_in_ring = [] @@ -300,7 +303,9 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): check_rank = check_rank_list.pop() for process_group in homogeneous_group: if check_rank in process_group: - rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1] + rank_to_append = ( + process_group[0] if process_group[1] == check_rank else process_group[1] + ) if rank_to_append not in ring_for_rank: stable_status = False rank_to_check_list.append(rank_to_append) @@ -314,7 +319,7 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): assert _power_of_two(self.world_size) power_of_two = int(math.log2(self.world_size)) median = power_of_two // 2 - balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median)) + balanced_logical_mesh_shape = (2**median, 2 ** (power_of_two - median)) row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1] balanced_logical_mesh = [] for row_index in range(row_size): @@ -348,7 +353,7 @@ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]): return best_logical_mesh def extract_alpha_beta_for_device_mesh(self): - ''' + """ Extract the mesh_alpha list and mesh_beta list based on the best logical mesh, which will be used to initialize the device mesh. @@ -360,7 +365,7 @@ def extract_alpha_beta_for_device_mesh(self): [2.5917552411556242e-05, 0.00010312341153621673] >>> print(mesh_beta) [5.875573704655635e-11, 4.7361584445959614e-12] - ''' + """ best_logical_mesh = self.search_best_logical_mesh() first_axis = [row[0] for row in best_logical_mesh] diff --git a/colossalai/device/calc_pipeline_strategy.py b/colossalai/device/calc_pipeline_strategy.py index 4ab72dfe60f0..72d432701ada 100644 --- a/colossalai/device/calc_pipeline_strategy.py +++ b/colossalai/device/calc_pipeline_strategy.py @@ -10,8 +10,10 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): while i <= num_devices_per_host: i *= 2 p += 1 - assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, " - f"while now num_devices_per_host = {num_devices_per_host}") + assert pow(2, p) == num_devices_per_host, ( + "Only supports the cases where num_devices_per_host is power of two, " + f"while now num_devices_per_host = {num_devices_per_host}" + ) if mode == "alpa": for i in range(p + 1): submesh_choices.append((1, pow(2, i))) @@ -24,18 +26,19 @@ def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"): return submesh_choices -def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, - best_configs): +def alpa_dp_impl( + num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost, best_configs +): """Implementation of Alpa DP for pipeline strategy - Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf + Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf - Arguments: - num_layers: K - num_devices: N*M - num_microbatches: B - submesh_choices: List[(n_i,m_i)] - compute_cost: t_intra - """ + Arguments: + num_layers: K + num_devices: N*M + num_microbatches: B + submesh_choices: List[(n_i,m_i)] + compute_cost: t_intra + """ # For f, layer ID start from 0 # f[#pipeline stages, layer id that is currently being considered, number of devices used] f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32) @@ -54,7 +57,7 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com for i in range(num_layers, k, -1): stage_cost = compute_cost[k, i, m] new_cost = f[s - 1, k, d - n_submesh_devices] + stage_cost - if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]): + if stage_cost <= max_stage_cost and new_cost < f[s, k, d]: f[s, k, d] = new_cost f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices]) f_argmin[s, k, d] = (i, m, best_configs[k, i, m]) @@ -75,34 +78,34 @@ def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, com res = [] while current_s > 0 and current_layer < num_layers and current_devices > 0: - next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices]) + next_start_layer, submesh_choice, autosharding_choice = f_argmin[current_s, current_layer, current_devices] assert next_start_layer != -1 and current_devices != -1 res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice)) current_s -= 1 current_layer = next_start_layer current_devices -= np.prod(np.array(submesh_choices[submesh_choice])) - assert (current_s == 0 and current_layer == num_layers and current_devices == 0) + assert current_s == 0 and current_layer == num_layers and current_devices == 0 return total_cost, res -def alpa_dp(num_layers, - num_devices, - num_microbatches, - submesh_choices, - num_autosharding_configs, - compute_cost, - gap=1e-6): +def alpa_dp( + num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost, gap=1e-6 +): """Alpa auto stage dynamic programming. - Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py + Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py Arguments: submesh_choices: List[(int,int)] num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh) compute_cost: np.array(num_layers,num_layers,num_submesh_choices,num_autosharding_configs) """ - assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices), - num_autosharding_configs), "Cost shape wrong." + assert np.shape(compute_cost) == ( + num_layers, + num_layers, + len(submesh_choices), + num_autosharding_configs, + ), "Cost shape wrong." all_possible_stage_costs = np.sort(np.unique(compute_cost)) best_cost = np.inf best_solution = None @@ -117,8 +120,9 @@ def alpa_dp(num_layers, break if max_stage_cost - last_max_stage_cost < gap: continue - cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, - max_stage_cost, best_configs) + cost, solution = alpa_dp_impl( + num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost, max_stage_cost, best_configs + ) if cost < best_cost: best_cost = cost best_solution = solution diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index f41af1161be1..72f199203a9d 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -40,14 +40,16 @@ class DeviceMesh: _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"} - def __init__(self, - physical_mesh_id: torch.Tensor, - mesh_shape: torch.Size = None, - logical_mesh_id: torch.Tensor = None, - mesh_alpha: List[float] = None, - mesh_beta: List[float] = None, - init_process_group: bool = False, - device: str = 'cuda'): + def __init__( + self, + physical_mesh_id: torch.Tensor, + mesh_shape: torch.Size = None, + logical_mesh_id: torch.Tensor = None, + mesh_alpha: List[float] = None, + mesh_beta: List[float] = None, + init_process_group: bool = False, + device: str = "cuda", + ): # ============================ # Physical & Logical Mesh IDs # ============================ @@ -57,9 +59,10 @@ def __init__(self, # logical mesh ids can be obtained via two ways # 1. provide physical mesh id and provide mesh shape # 2. directly supply the logical mesh id - assert mesh_shape is None or logical_mesh_id is None, \ - "Only one of mesh_shape and logical_mesh_id can be specified." \ + assert mesh_shape is None or logical_mesh_id is None, ( + "Only one of mesh_shape and logical_mesh_id can be specified." "Logical mesh IDs are obtained from either mesh_shape + physical_mesh_id or directly from the user-supplied logical_mesh_id" + ) if logical_mesh_id is None: self._mesh_shape = mesh_shape @@ -71,12 +74,15 @@ def __init__(self, # ensure two things: # 1. logical and physical mesh IDs should contain the same elements # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed - assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \ - "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." - assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \ - "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again." - assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \ - "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." + assert torch.equal( + torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id) + ), "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." + assert ( + torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel() + ), "Found duplicate IDs in the physical_mesh_id and this is not allowed, please check your physical_mesh_id again." + assert ( + torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel() + ), "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." # =============================================== # coefficient for alpha-beta communication model @@ -92,8 +98,9 @@ def __init__(self, self.mesh_beta = tuple(mesh_beta) # ensure the alpha and beta have the same shape - assert len(self.mesh_alpha) == len(self.mesh_beta), \ - "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." + assert len(self.mesh_alpha) == len( + self.mesh_beta + ), "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." # ========================= # Device for Process Group @@ -109,8 +116,9 @@ def __init__(self, # : [ , , , ...] # } self._global_to_local_rank_mapping = dict() - self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping, - tensor=self.logical_mesh_id) + self._init_global_to_logical_rank_mapping( + mapping=self._global_to_local_rank_mapping, tensor=self.logical_mesh_id + ) # create process group self._process_group_dict = {} @@ -194,8 +202,9 @@ def _get_device_by_backend(process_group): device_list = [_get_device_by_backend(pg) for pg in process_group] # make sure all devices are the same - assert all([device == device_list[0] for device in device_list]), \ - "All devices should be the same, please check your input process groups are created with the same distributed backend." + assert all( + [device == device_list[0] for device in device_list] + ), "All devices should be the same, please check your input process groups are created with the same distributed backend." # create a fake physical mesh id # as we only get the process group associated with the current process, @@ -270,7 +279,7 @@ def __deepcopy__(self, memo) -> "DeviceMesh": result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k != '_process_group_dict': + if k != "_process_group_dict": setattr(result, k, __import__("copy").deepcopy(v, memo)) else: # process group cannot be copied @@ -278,10 +287,9 @@ def __deepcopy__(self, memo) -> "DeviceMesh": setattr(result, k, v) return result - def _init_global_to_logical_rank_mapping(self, - mapping: Dict, - tensor: torch.Tensor, - index_list: List[int] = []) -> Dict[int, List[int]]: + def _init_global_to_logical_rank_mapping( + self, mapping: Dict, tensor: torch.Tensor, index_list: List[int] = [] + ) -> Dict[int, List[int]]: """ Build a global rank to local rank mapping for each process group in different axis in the logical device mesh. @@ -311,15 +319,19 @@ def _init_global_to_logical_rank_mapping(self, self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index]) def init_logical_process_group(self): - ''' + """ This method is used to initialize the logical process groups which will be used in communications among logical device mesh. Note: if init_process_group set to False, you have to call this method manually. Otherwise, the communication related function, such as ShapeConsistencyManager.apply will raise errors. - ''' + """ # sanity check - assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group" - assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice" + assert ( + dist.is_initialized + ), "The torch.distributed should be initialized before calling init_logical_process_group" + assert ( + not self._is_initialized + ), "The logical process group has been initialized, do not call init_logical_process_group twice" # update the global rank of the current process self._global_rank_of_current_process = dist.get_rank() @@ -389,7 +401,7 @@ def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[i return local_ranks def _collate_global_ranks_in_same_process_group(self, global_rank): - ''' + """ Give a global rank and return all global ranks involved in its associated process group in each axis. Example: @@ -414,7 +426,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank): 0: [0, 4, 8, 12], 1: [0, 1, 2, 3] # } - ''' + """ # We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping # for self._global_to_local_rank_mapping # the key is the global rank @@ -437,7 +449,6 @@ def _collate_global_ranks_in_same_process_group(self, global_rank): # in the same process group in the given axis # the _local_rank refers to the local rank of the current process for _local_rank in range(self.logical_mesh_id.shape[dim]): - # if this dimension is not initialized yet, # initialize it with an empty array if dim not in processes_in_the_same_process_group: @@ -478,29 +489,37 @@ def flatten(self): flatten_mesh_shape_size = len(self._mesh_shape) flatten_mesh_shape = [self.num_devices] - return DeviceMesh(self._physical_mesh_id, - tuple(flatten_mesh_shape), - mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), - mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), - init_process_group=self._init_process_group) + return DeviceMesh( + self._physical_mesh_id, + tuple(flatten_mesh_shape), + mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), + mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), + init_process_group=self._init_process_group, + ) def all_gather_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + - 0.1) + return self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.1 def all_reduce_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes + - 0.01) + return ( + self.mesh_alpha[mesh_dim] + + self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes + + 0.01 + ) def reduce_scatter_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + - 0.001) + return ( + self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes + 0.001 + ) def all_to_all_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] penalty_factor = num_devices / 2.0 - return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * - (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) + return ( + self.mesh_alpha[mesh_dim] + + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + + 0.001 + ) diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py index 0444a4816273..4d40d5badfd0 100644 --- a/colossalai/fx/_compatibility.py +++ b/colossalai/fx/_compatibility.py @@ -2,16 +2,14 @@ import torch -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) if TORCH_MAJOR == 1 and TORCH_MINOR < 12: META_COMPATIBILITY = False elif TORCH_MAJOR == 1 and TORCH_MINOR == 12: - from . import _meta_regist_12 META_COMPATIBILITY = True elif TORCH_MAJOR == 1 and TORCH_MINOR == 13: - from . import _meta_regist_13 META_COMPATIBILITY = True elif TORCH_MAJOR == 2: META_COMPATIBILITY = True @@ -36,7 +34,7 @@ def decorator(func): else: def wrapper(*args, **kwargs): - raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}') + raise RuntimeError(f"Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}") return wrapper diff --git a/colossalai/fx/_meta_regist_12.py b/colossalai/fx/_meta_regist_12.py index 52e8d63ae543..63f88682e85a 100644 --- a/colossalai/fx/_meta_regist_12.py +++ b/colossalai/fx/_meta_regist_12.py @@ -3,7 +3,7 @@ # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml # for more meta_registrations -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Union import torch from torch.utils._pytree import tree_map @@ -16,13 +16,11 @@ def register_meta(op, register_dispatcher=True): - def wrapper(f): - def add_func(op): meta_table[op] = f if register_dispatcher: - name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) + name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__ try: meta_lib.impl(name, f) except: @@ -48,7 +46,6 @@ def meta_conv( output_padding: List[int], groups: int, ): - def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ Formula to apply to calculate the length of some dimension of the output @@ -125,7 +122,8 @@ def calc_conv_nd_return_shape( kernel_size[i], stride[i], output_padding_list[i], - )) + ) + ) else: ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) return ret_shape @@ -159,22 +157,42 @@ def pick_memory_format(): shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) mem_fmt = pick_memory_format() - out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] return out @register_meta(aten._convolution.default) -def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int], - padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int, - *extra_args): +def meta_conv_1( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, + *extra_args, +): out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) return out @register_meta(aten.convolution_backward.default) -def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, - padding, dilation, transposed, output_padding, groups, output_mask): - return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') +def meta_conv_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, +): + return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device="meta") # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp @@ -208,7 +226,6 @@ def meta_cuda_rnn( batch_sizes, dropout_state, ): - is_input_packed = len(batch_sizes) != 0 if is_input_packed: seq_length = len(batch_sizes) @@ -224,8 +241,11 @@ def meta_cuda_rnn( if is_input_packed: out_shape = [batch_sizes_sum, out_size * num_directions] else: - out_shape = ([mini_batch, seq_length, out_size * - num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions]) + out_shape = ( + [mini_batch, seq_length, out_size * num_directions] + if batch_first + else [seq_length, mini_batch, out_size * num_directions] + ) output = input.new_empty(out_shape) cell_shape = [num_layers * num_directions, mini_batch, hidden_size] @@ -242,18 +262,20 @@ def meta_cuda_rnn( # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp @register_meta(aten._cudnn_rnn_backward.default) -def meta_cudnn_rnn_backward(input: torch.Tensor, - weight: torch.Tensor, - weight_stride0: int, - hx: torch.Tensor, - cx: Optional[torch.Tensor] = None, - *args, - **kwargs): +def meta_cudnn_rnn_backward( + input: torch.Tensor, + weight: torch.Tensor, + weight_stride0: int, + hx: torch.Tensor, + cx: Optional[torch.Tensor] = None, + *args, + **kwargs, +): print(input, weight, hx, cx) grad_input = torch.empty_like(input) grad_weight = torch.empty_like(weight) grad_hx = torch.empty_like(hx) - grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta') + grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device="meta") return grad_input, grad_weight, grad_hx, grad_cx @@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini n_input = input.size(1) output = torch.empty_like(input) - running_mean = torch.empty((n_input), device='meta') - running_var = torch.empty((n_input), device='meta') + running_mean = torch.empty((n_input), device="meta") + running_var = torch.empty((n_input), device="meta") return output, running_mean, running_var # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp @register_meta(aten.native_batch_norm_backward.default) -def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, - save_invstd, train, eps, output_mask): +def meta_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, +): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(weight) @@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, n_input = input.size(1) output = torch.empty_like(input) - running_mean = torch.empty((n_input), device='meta') - running_var = torch.empty((n_input), device='meta') - reserve = torch.empty((0), dtype=torch.uint8, device='meta') + running_mean = torch.empty((n_input), device="meta") + running_var = torch.empty((n_input), device="meta") + reserve = torch.empty((0), dtype=torch.uint8, device="meta") return output, running_mean, running_var, reserve @@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, # in training mode (evaluation mode batchnorm has a different algorithm), # which is why this doesn't accept a 'training' parameter. @register_meta(aten.cudnn_batch_norm_backward.default) -def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, - save_mean, save_invstd, eps, reserve): +def meta_cudnn_bn_backward( + dY: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + running_mean, + running_var, + save_mean, + save_invstd, + eps, + reserve, +): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(weight) @@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): n_input = input.size(1) output = torch.empty_like(input) - running_mean = torch.empty((bs, n_input, 1), device='meta') - running_var = torch.empty((bs, n_input, 1), device='meta') + running_mean = torch.empty((bs, n_input, 1), device="meta") + running_var = torch.empty((bs, n_input, 1), device="meta") return output, running_mean, running_var # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp @register_meta(aten.native_layer_norm_backward.default) -def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, - grad_input_mask): +def meta_ln_backward( + dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask +): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(bias) @@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices): result: List[Optional[torch.Tensor]] = [] for i, index in enumerate(indices): if index is not None: - assert index.dtype in [torch.long, torch.int8, torch.bool],\ - "tensors used as indices must be long, byte or bool tensors" + assert index.dtype in [ + torch.long, + torch.int8, + torch.bool, + ], "tensors used as indices must be long, byte or bool tensors" if index.dtype in [torch.int8, torch.bool]: nonzero = index.nonzero() k = len(result) assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" for j in range(index.ndim): - assert index.shape[j] == self.shape[ - k + - j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + assert ( + index.shape[j] == self.shape[k + j] + ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" result.append(nonzero.select(1, j)) else: result.append(index) @@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices): # ============================== Embedding ========================================= # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp @register_meta(aten.embedding_dense_backward.default) -def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, - scale_grad_by_freq): - return torch.empty((num_weights, grad_output.size(-1)), - dtype=grad_output.dtype, - device=grad_output.device, - layout=grad_output.layout) +def meta_embedding_dense_backward( + grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq +): + return torch.empty( + (num_weights, grad_output.size(-1)), + dtype=grad_output.dtype, + device=grad_output.device, + layout=grad_output.layout, + ) # ============================== Dropout =========================================== diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py index 33b164800262..dfb5754d71c1 100644 --- a/colossalai/fx/codegen/activation_checkpoint_codegen.py +++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Tuple +from typing import Any, Dict, Iterable, List, Tuple import torch @@ -18,6 +18,7 @@ magic_methods, ) from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + CODEGEN_AVAILABLE = True except: from torch.fx.graph import ( @@ -32,12 +33,13 @@ magic_methods, ) from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + CODEGEN_AVAILABLE = False if CODEGEN_AVAILABLE: - __all__ = ['ActivationCheckpointCodeGen'] + __all__ = ["ActivationCheckpointCodeGen"] else: - __all__ = ['python_code_with_activation_checkpoint'] + __all__ = ["python_code_with_activation_checkpoint"] def _gen_saved_tensors_hooks(): @@ -125,15 +127,14 @@ def _find_ckpt_regions(nodes: List[Node]): Find the checkpoint regions given a list of consecutive nodes. The outputs will be list of tuples, each tuple is in the form of (start_index, end_index). """ - ckpt_nodes = [] ckpt_regions = [] start = -1 end = -1 current_region = None for idx, node in enumerate(nodes): - if 'activation_checkpoint' in node.meta: - act_ckpt_label = node.meta['activation_checkpoint'] + if "activation_checkpoint" in node.meta: + act_ckpt_label = node.meta["activation_checkpoint"] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -150,7 +151,7 @@ def _find_ckpt_regions(nodes: List[Node]): current_region = act_ckpt_label start = idx end = -1 - elif current_region is not None and not 'activation_checkpoint' in node.meta: + elif current_region is not None and not "activation_checkpoint" in node.meta: # used to check the case below # node ckpt states = [ckpt, ckpt, non-ckpt] end = idx - 1 @@ -178,8 +179,8 @@ def _find_offload_regions(nodes: List[Node]): current_region = None for idx, node in enumerate(nodes): - if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable): - act_offload_label = node.meta['activation_offload'] + if "activation_offload" in node.meta and isinstance(node.meta["activation_offload"], Iterable): + act_offload_label = node.meta["activation_offload"] if current_region == None: current_region = act_offload_label @@ -226,9 +227,9 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen """ Generate the checkpoint function call code text """ - outputs = ', '.join(output_vars) - inputs = ', '.join(input_vars) - return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})' + outputs = ", ".join(output_vars) + inputs = ", ".join(input_vars) + return f"{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})" def _end_of_ckpt(node: Node, check_idx: int) -> bool: @@ -240,9 +241,9 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool: Returns: bool """ - if 'activation_checkpoint' in node.meta: - if isinstance(node.meta['activation_checkpoint'], list): - return node.meta['activation_checkpoint'][check_idx] == None + if "activation_checkpoint" in node.meta: + if isinstance(node.meta["activation_checkpoint"], list): + return node.meta["activation_checkpoint"][check_idx] == None else: return False else: @@ -260,11 +261,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0): current_region = None for idx, node in enumerate(nodes): - if 'activation_checkpoint' in node.meta: - if isinstance(node.meta['activation_checkpoint'], int): - act_ckpt_label = node.meta['activation_checkpoint'] + if "activation_checkpoint" in node.meta: + if isinstance(node.meta["activation_checkpoint"], int): + act_ckpt_label = node.meta["activation_checkpoint"] else: - act_ckpt_label = node.meta['activation_checkpoint'][check_idx] + act_ckpt_label = node.meta["activation_checkpoint"][check_idx] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -298,13 +299,9 @@ def _find_nested_ckpt_regions(nodes, check_idx=0): return ckpt_regions -def emit_ckpt_func(body, - ckpt_func, - node_list: List[Node], - emit_node_func, - delete_unused_value_func, - level=0, - in_ckpt=False): +def emit_ckpt_func( + body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, level=0, in_ckpt=False +): """Emit ckpt function in nested way Args: body: forward code, in recursive calls, this part will be checkpoint @@ -321,17 +318,17 @@ def emit_ckpt_func(body, inputs, outputs = _find_input_and_output_nodes(node_list) # if the current checkpoint function use int as label, using old generation method - if isinstance(node_list[0].meta['activation_checkpoint'], int): - label = node_list[0].meta['activation_checkpoint'] + if isinstance(node_list[0].meta["activation_checkpoint"], int): + label = node_list[0].meta["activation_checkpoint"] ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") for node in node_list: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = node_list[0].meta.get('activation_offload', False) + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") + activation_offload = node_list[0].meta.get("activation_offload", False) usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) usage += "\n" body.append(usage) @@ -340,12 +337,12 @@ def emit_ckpt_func(body, else: # label given by each layer, e.g. if you are currently at level [0, 1, 1] # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]]) + label = "_".join([str(idx) for idx in node_list[0].meta["activation_checkpoint"][: level + 1]]) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") # if there is more level to fetch - if level + 1 < len(node_list[0].meta['activation_checkpoint']): + if level + 1 < len(node_list[0].meta["activation_checkpoint"]): ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] @@ -358,38 +355,45 @@ def emit_ckpt_func(body, break if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] - emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, - delete_unused_value_func, level + 1, True) + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] + emit_ckpt_func( + ckpt_func, + ckpt_func_buffer, + ckpt_node_list, + emit_node_func, + delete_unused_value_func, + level + 1, + True, + ) node_idx += len(ckpt_node_list) else: node = node_list[node_idx] emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) node_idx += 1 - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") ckpt_func += ckpt_func_buffer - activation_offload = node_list[0].meta.get('activation_offload', False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + activation_offload = node_list[0].meta.get("activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n" if in_ckpt: - usage = ' ' + usage + usage = " " + usage body.append(usage) # last level else: for node in node_list: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) - ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n') - activation_offload = node_list[0].meta.get('activation_offload', False) - usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n' + ckpt_func.append(" " + _gen_ckpt_output(outputs) + "\n\n") + activation_offload = node_list[0].meta.get("activation_offload", False) + usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + "\n" if in_ckpt: - usage = ' ' + usage + usage = " " + usage body.append(usage) @@ -420,7 +424,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod # find the input and output var names for each offload region for idx, (start, end) in enumerate(offload_regions): - offload_node_list = node_list[start:end + 1] + offload_node_list = node_list[start : end + 1] inputs, outputs = _find_input_and_output_nodes(offload_node_list) offload_inputs.append(inputs) offload_outputs.append(outputs) @@ -436,7 +440,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod # process ckpt_regions if node_idx in start_idx: - ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1] + ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1] emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func) node_idx += len(ckpt_node_list) @@ -470,7 +474,7 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod if within_offload_region: emit_node_func(node, body) - body[-1] = ' ' + body[-1] + body[-1] = " " + body[-1] delete_unused_value_func(node, body) else: @@ -508,14 +512,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # find the input and output var names for each region for idx, (start, end) in enumerate(ckpt_regions): - ckpt_node_list = node_list[start:end + 1] + ckpt_node_list = node_list[start : end + 1] inputs, outputs = _find_input_and_output_nodes(ckpt_node_list) input_vars.append(inputs) output_vars.append(outputs) # find the input and output var names for each offload region for idx, (start, end) in enumerate(offload_regions): - offload_node_list = node_list[start:end + 1] + offload_node_list = node_list[start : end + 1] inputs, outputs = _find_input_and_output_nodes(offload_node_list) offload_inputs.append(inputs) offload_outputs.append(outputs) @@ -527,7 +531,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if idx in start_idx: label = start_idx.index(idx) ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label]) - ckpt_func.append(f'{ckpt_fn_def}\n') + ckpt_func.append(f"{ckpt_fn_def}\n") within_ckpt_region = True if idx in offload_starts: @@ -559,12 +563,12 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # NOTE: currently we separate body and ckpt_func definition if within_ckpt_region: emit_node_func(node, ckpt_func) - ckpt_func[-1] = ' ' + ckpt_func[-1] + ckpt_func[-1] = " " + ckpt_func[-1] delete_unused_value_func(node, ckpt_func) elif within_offload_region: emit_node_func(node, body) - body[-1] = ' ' + body[-1] + body[-1] = " " + body[-1] delete_unused_value_func(node, body) else: @@ -576,13 +580,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # generate return statement label = end_idx.index(idx) return_statement = _gen_ckpt_output(output_vars[label]) - return_statement = f' {return_statement}\n\n' + return_statement = f" {return_statement}\n\n" ckpt_func.append(return_statement) # we need to check if the checkpoint need to offload the input start_node_idx = start_idx[label] - if 'activation_offload' in node_list[start_node_idx].meta: - activation_offload = node_list[start_node_idx].meta['activation_offload'] + if "activation_offload" in node_list[start_node_idx].meta: + activation_offload = node_list[start_node_idx].meta["activation_offload"] else: activation_offload = False @@ -594,8 +598,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if input_node.op != "placeholder": non_leaf_input = 1 for user in input_node.users: - if 'activation_checkpoint' in user.meta: - if user.meta['activation_checkpoint'] == label: + if "activation_checkpoint" in user.meta: + if user.meta["activation_checkpoint"] == label: if user.op == "call_module": if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"): use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace @@ -610,7 +614,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, # generate checkpoint function call in a new line usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant) - usage += '\n' + usage += "\n" body.append(usage) within_ckpt_region = False @@ -621,7 +625,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, if CODEGEN_AVAILABLE: class ActivationCheckpointCodeGen(CodeGen): - def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] @@ -629,7 +632,7 @@ def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> Py wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] + maybe_return_annotation: List[str] = [""] def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -637,7 +640,7 @@ def add_global(name_hint: str, obj: Any): Graph, like functions or types. Returns: the global name that should be used to reference 'obj' in generated source. """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -662,16 +665,16 @@ def add_global(name_hint: str, obj: Any): def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): # This is a generic type, e.g. typing.List[torch.Tensor] origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) - if hasattr(o, '__args__'): + if hasattr(o, "__args__"): # Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__] @@ -690,19 +693,18 @@ def type_repr(o: Any): return add_global(typename, o) def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - def _get_repr(arg): # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, '_fields'): + if isinstance(arg, tuple) and hasattr(arg, "_fields"): qualified_name = _get_qualified_name(type(arg)) global_name = add_global(qualified_name, type(arg)) return f"{global_name}{repr(tuple(arg))}" return repr(arg) - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' + return f"{args_s}, {kwargs_s}" return args_s or kwargs_s # Run through reverse nodes and record the first instance of a use @@ -728,90 +730,101 @@ def delete_unused_values(user: Node, body): not used in the remainder of the code are freed and the memory usage of the code is optimal. """ - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) + body.append(f"; {to_delete_str}\n") else: - body.append('\n') + body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods: - body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}') + if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods: + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" body.append(self.generate_output(node.args[0])) return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") # Modified for activation checkpointing ckpt_func = [] # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes): + if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in nodes): emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) else: emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values) @@ -820,13 +833,13 @@ def emit_node(node: Node, body): # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') + body.append("pass\n") if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" if self._body_transformer: body = self._body_transformer(body) @@ -837,11 +850,11 @@ def emit_node(node: Node, body): # as we need colossalai.utils.checkpoint, we need to import colossalai # in forward function prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = ''.join(ckpt_func) + prologue + prologue = "".join(ckpt_func) + prologue prologue = prologue - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) fn_code = f""" {wrap_stmts} {prologue} @@ -861,7 +874,7 @@ def python_code_with_activation_checkpoint(self, root_module: str, namespace: _N wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [''] + maybe_return_annotation: List[str] = [""] def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -869,7 +882,7 @@ def add_global(name_hint: str, obj: Any): Graph, like functions or types. Returns: the global name that should be used to reference 'obj' in generated source. """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -894,12 +907,12 @@ def add_global(name_hint: str, obj: Any): def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) # This is a generic type, e.g. typing.List[torch.Tensor] - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) @@ -934,84 +947,94 @@ def delete_unused_values(user: Node, body): not used in the remainder of the code are freed and the memory usage of the code is optimal. """ - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {to_delete_str}\n') + to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"]) + body.append(f"; {to_delete_str}\n") else: - body.append('\n') + body.append("\n") # NOTE: we add a variable to distinguish body and ckpt_func def emit_node(node: Node, body): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' - if node.op == 'placeholder': + maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}" + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}") + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods: + if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods: assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}') + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) return body.append( - f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}") return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" if self._pytree_info is None: - body.append(f'return {repr(node.args[0])}') + body.append(f"return {repr(node.args[0])}") else: - body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)') + body.append(f"return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)") return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") # Modified for activation checkpointing ckpt_func = [] # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes): + if any(isinstance(node.meta.get("activation_checkpoint", None), Iterable) for node in self.nodes): emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) else: emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values) @@ -1020,33 +1043,34 @@ def emit_node(node: Node, body): # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') + body.append("pass\n") if self._pytree_info is not None: orig_args = self._pytree_info.orig_args - has_orig_self = (orig_args[0] == 'self') + has_orig_self = orig_args[0] == "self" if has_orig_self: - free_vars.insert(0, 'self') - if len(free_vars) > 0: # pytree has placeholders in it + free_vars.insert(0, "self") + if len(free_vars) > 0: # pytree has placeholders in it body.insert( 0, - f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n") + f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n", + ) else: orig_args = free_vars if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" - ckpt_func = ''.join(ckpt_func) + ckpt_func = "".join(ckpt_func) # If the original function didn't have self as its first argument, we # would have added it. - if len(orig_args) == 0 or orig_args[0] != 'self': - orig_args.insert(0, 'self') - code = ''.join(body) - code = '\n'.join(' ' + line for line in code.split('\n')) + if len(orig_args) == 0 or orig_args[0] != "self": + orig_args.insert(0, "self") + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) # as we need colossalai.utils.checkpoint, we need to import colossalai # in forward function diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py index ebb9975f27db..8429a9607f7a 100644 --- a/colossalai/fx/graph_module.py +++ b/colossalai/fx/graph_module.py @@ -1,32 +1,35 @@ import os import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Type, Union +from typing import Any, Dict, Optional, Union import torch import torch.nn as nn from torch.nn.modules.module import _addindent try: - from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen - from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall + from torch.fx.graph import Graph, PythonCode, _PyTreeCodeGen + from torch.fx.graph_module import GraphModule, _exec_with_source, _forward_from_src, _WrappedCall from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen + COLOGM = True except: from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule + COLOGM = False if COLOGM: class ColoGraphModule(GraphModule): - - def __init__(self, - root: Union[torch.nn.Module, Dict[str, Any]], - graph: Graph, - class_name: str = 'GraphModule', - ckpt_codegen: bool = True): + def __init__( + self, + root: Union[torch.nn.Module, Dict[str, Any]], + graph: Graph, + class_name: str = "GraphModule", + ckpt_codegen: bool = True, + ): if ckpt_codegen: graph.set_codegen(ActivationCheckpointCodeGen()) super().__init__(root, graph, class_name) @@ -60,7 +63,7 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module='self') + python_code = self._graph.python_code(root_module="self") self._code = python_code.src # To split ckpt functions code and forward code @@ -83,8 +86,8 @@ def recompile(self) -> PythonCode: # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. cls_call = cls.__call__ if "__call__" in vars(cls) else None - if '_wrapped_call' not in vars(cls): - cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] def call_wrapped(self, *args, **kwargs): return self._wrapped_call(self, *args, **kwargs) @@ -108,7 +111,7 @@ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModul """ folder = Path(folder) Path(folder).mkdir(exist_ok=True) - torch.save(self.state_dict(), folder / 'state_dict.pt') + torch.save(self.state_dict(), folder / "state_dict.pt") tab = " " * 4 # we add import colossalai here @@ -125,7 +128,13 @@ def __init__(self): def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: safe_reprs = [ - nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, ] if type(module) in safe_reprs: return f"{module.__repr__()}" @@ -136,10 +145,10 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: for module_name, module in self.named_children(): module_str = _gen_model_repr(module_name, module) if module_str is None: - module_file = folder / f'{module_name}.pt' + module_file = folder / f"{module_name}.pt" torch.save(module, module_file) blobified_modules.append(module_name) - module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") module_str = f"torch.load(r'{module_file}') # {module_repr}" model_str += f"{tab*2}self.{module_name} = {module_str}\n" @@ -156,19 +165,20 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" model_str += f"{_addindent(self.code, 4)}\n" - module_file = folder / 'module.py' + module_file = folder / "module.py" module_file.write_text(model_str) - init_file = folder / '__init__.py' - init_file.write_text('from .module import *') + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") if len(blobified_modules) > 0: - warnings.warn("Was not able to save the following children modules as reprs -" - f"saved as pickled files instead: {blobified_modules}") + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) else: class ColoGraphModule(GraphModule): - - def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): + def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule"): super().__init__(root, graph, class_name) diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py index 245ba5d776da..99c8faaa0cc6 100644 --- a/colossalai/fx/passes/adding_split_node_pass.py +++ b/colossalai/fx/passes/adding_split_node_pass.py @@ -1,8 +1,6 @@ import numpy as np import torch import tqdm -from torch.fx import symbolic_trace -from torch.fx.node import Node from colossalai.fx.passes.split_module import split_module @@ -29,15 +27,15 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01): accumulate_bwd_flop = 0 block_nodes = [] for node in gm.graph.nodes: - if 'block_split' in node.name: + if "block_split" in node.name: continue accumulate_fwd_flop += node.fwd_flop accumulate_bwd_flop += node.bwd_flop if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop: with gm.graph.inserting_after(node): - block_node = gm.graph.create_node('call_function', block_split) - setattr(block_node, 'fwd_flop', accumulate_fwd_flop) - setattr(block_node, 'bwd_flop', accumulate_bwd_flop) + block_node = gm.graph.create_node("call_function", block_split) + setattr(block_node, "fwd_flop", accumulate_fwd_flop) + setattr(block_node, "bwd_flop", accumulate_bwd_flop) accumulate_fwd_flop = 0 accumulate_bwd_flop = 0 block_nodes.append(block_node) @@ -47,7 +45,7 @@ def construct_blocks(gm: torch.fx.GraphModule, limit=0.01): def remove_blocks(gm: torch.fx.GraphModule): for node in gm.graph.nodes: - if (node.op, node.target) == ('call_function', block_split): + if (node.op, node.target) == ("call_function", block_split): gm.graph.erase_node(node) @@ -55,8 +53,8 @@ def get_compute_costs(node_list): num_nodes = len(node_list) all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64) - for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0): - for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False): + for start in tqdm.tqdm(range(num_nodes), desc="start pos", position=0): + for end in tqdm.tqdm(range(start, num_nodes), desc="end pos", position=1, leave=False): selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)] all_compute_cost[start, end] = sum(selected_flops) @@ -78,12 +76,14 @@ def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_cost # record start node index for next stage in this partition f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32) f[0, num_nodes] = 0 - for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks - for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False): - for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False): + for s in tqdm.tqdm( + range(1, num_stages + 1), desc="stage", position=2, leave=False + ): # pylint: disable=too-many-nested-blocks + for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc="start node", position=3, leave=False): + for k in tqdm.tqdm(range(num_nodes, i, -1), desc="mid node", position=4, leave=False): stage_cost = compute_costs[i, k - 1] new_cost = f[s - 1, k] + stage_cost - if (stage_cost <= max_compute_cost and new_cost < f[s, i]): + if stage_cost <= max_compute_cost and new_cost < f[s, i]: f[s, i] = new_cost f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost) f_argmin[s, i] = k @@ -113,7 +113,7 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche best_cost = np.inf best_solution = None last_max_compute_cost = 0.0 - gap = 1e6 # temporary magic number, unit: flops + gap = 1e6 # temporary magic number, unit: flops for max_compute_cost in tqdm.tqdm(max_compute_costs): # Pruning to reduce search space. @@ -122,8 +122,9 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche if max_compute_cost - last_max_compute_cost < gap: continue - cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs, - max_compute_cost) + cost, solution = do_dp_split_gpipe_impl( + len(node_list), num_stages, num_microbatches, compute_costs, max_compute_cost + ) if cost < best_cost: best_cost = cost @@ -137,15 +138,15 @@ def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatche # split_mode: # 'node': fx_node # 'block': many fx_nodes construct a block -def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01): - assert mode in ['node', 'block'] +def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode="block", block_limit=0.01): + assert mode in ["node", "block"] # nodes or blocks will be used in partition. node_list = [] - if mode == 'node': + if mode == "node": for node in gm.graph.nodes: node_list.append(node) - elif mode == 'block': + elif mode == "block": node_list = construct_blocks(gm, limit=block_limit) else: pass @@ -154,16 +155,16 @@ def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches) - for (_, next_start_node) in best_solution: + for _, next_start_node in best_solution: if pp_size <= 1: break node = node_list[next_start_node] with gm.graph.inserting_before(node): - split_node = gm.graph.create_node('call_function', pipe_split) + split_node = gm.graph.create_node("call_function", pipe_split) pp_size -= 1 # remove block node if possible - if mode == 'block': + if mode == "block": remove_blocks(gm) gm.recompile() @@ -178,7 +179,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): # To use avgcompute_split_pass, we need run meta_info_prop interpreter first. # If nodes don't have meta info, this pass will fall back to normal balanced split pass. check_node = list(mod_graph.nodes)[0] - if 'tensor_meta' not in check_node.meta: + if "tensor_meta" not in check_node.meta: return balanced_split_pass(gm, pp_size) total_fwd_flop = 0 @@ -190,7 +191,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): for node in mod_graph.nodes: if pp_size <= 1: break - if 'pipe_split' in node.name: + if "pipe_split" in node.name: continue accumulate_fwd_flop += node.fwd_flop if accumulate_fwd_flop >= partition_flop: @@ -199,7 +200,7 @@ def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 partition_flop = total_fwd_flop // pp_size with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -218,12 +219,12 @@ def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int): if accumulate_num_node >= avg_num_node: accumulate_num_node = 0 pp_size -= 1 - if node.next.op == 'output': + if node.next.op == "output": with mod_graph.inserting_before(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) else: with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -250,18 +251,18 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 # If the next node is output node, we will insert split annotation before # node to make sure there is at least one node in last partition. - if node.next.op == 'output': + if node.next.op == "output": with mod_graph.inserting_before(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) else: with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) if pp_size > 1: node_counter = 0 for node in mod_graph.nodes: if pp_size <= 1: break - if node.op == 'placeholder': + if node.op == "placeholder": continue elif node_counter == 0: node_counter += 1 @@ -269,7 +270,7 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 node_counter = 0 with mod_graph.inserting_before(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -283,7 +284,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): # To use balanced_split_pass_v2, we need run meta_info_prop interpreter first. # If nodes don't have meta info, this pass will fall back to normal balanced split pass. check_node = list(mod_graph.nodes)[0] - if 'tensor_meta' not in check_node.meta: + if "tensor_meta" not in check_node.meta: return balanced_split_pass(gm, pp_size) total_element_size = 0 @@ -295,7 +296,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): for node in mod_graph.nodes: if pp_size <= 1: break - if 'pipe_split' in node.name: + if "pipe_split" in node.name: continue accumulate_node_size += node.node_size if accumulate_node_size >= partition_size: @@ -304,7 +305,7 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int): pp_size -= 1 partition_size = total_element_size // pp_size with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -333,7 +334,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int): accumulate_layer_amount = 0 pp_size -= 1 with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm @@ -346,7 +347,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output def split_callback(n: torch.fx.Node): nonlocal part_idx - if (n.op, n.target) == ('call_function', pipe_split): + if (n.op, n.target) == ("call_function", pipe_split): part_idx += 1 return part_idx @@ -355,7 +356,7 @@ def split_callback(n: torch.fx.Node): for name, submodule in split_mod.named_modules(): if isinstance(submodule, torch.fx.GraphModule): for node in submodule.graph.nodes: - if (node.op, node.target) == ('call_function', pipe_split): + if (node.op, node.target) == ("call_function", pipe_split): submodule.graph.erase_node(node) submodule.recompile() split_submodules.append(submodule) diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py index 81ac64205528..5440a4eadbbf 100644 --- a/colossalai/fx/passes/concrete_info_prop.py +++ b/colossalai/fx/passes/concrete_info_prop.py @@ -1,5 +1,5 @@ from dataclasses import asdict -from typing import Any, Dict, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import torch.fx @@ -85,10 +85,10 @@ def run_node(self, n: Node) -> Any: self._is_proped = True result, meta_info = super().run_node(n) - n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0)) - n.meta['type'] = type(result) + setattr(n, "node_size", n.meta.get("fwd_mem_tmp", 0) + n.meta.get("fwd_mem_out", 0)) + n.meta["type"] = type(result) # retain the autograd graph for param in self.module.parameters(): @@ -98,7 +98,7 @@ def run_node(self, n: Node) -> Any: # Main Node running APIs @compatibility(is_backward_compatible=True) - def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``placeholder`` node. Note that this is stateful: ``Interpreter`` maintains an internal iterator over @@ -119,7 +119,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return super().placeholder(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``get_attr`` node. Will retrieve an attribute value from the ``Module`` hierarchy of ``self.module``. @@ -138,7 +138,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st return super().get_attr(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node with meta tensor and return the result and its meta profile. @@ -157,7 +157,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di return profile_function(target, self.device)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node with meta tensor and return the result and its meta profile. @@ -175,7 +175,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return profile_method(target, self.device)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_module`` node with meta tensor and return the result and its meta profile. @@ -197,7 +197,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return profile_module(submod, self.device)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute an ``output`` node. This really just retrieves the value referenced by the ``output`` node and returns it. @@ -228,7 +228,7 @@ def propagate(self, *args): """ return self.run(*args) - def summary(self, unit: str = 'MB') -> str: + def summary(self, unit: str = "MB") -> str: """ Summarizes the memory and FLOPs statistics of the `GraphModule` in tabular format. Note that this API requires the ``tabulate`` module @@ -238,9 +238,11 @@ def summary(self, unit: str = 'MB') -> str: try: from tabulate import tabulate except ImportError: - print("`summary` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." @@ -249,10 +251,10 @@ def summary(self, unit: str = 'MB') -> str: def mem_repr(mem: int) -> str: unit_divisor_map = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, + "kb": 1024, + "mb": 1024**2, + "gb": 1024**3, + "tb": 1024**4, } return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" @@ -261,30 +263,32 @@ def time_repr(time: float): for node in self.module.graph.nodes: node: Node - node_summaries.append([ - node.op, - str(node), - time_repr(node.meta['fwd_time']), - time_repr(node.meta['bwd_time']), - node.meta['save_fwd_in'], - mem_repr(node.meta['fwd_mem_out']), - mem_repr(node.meta['fwd_mem_tmp']), - mem_repr(node.meta['bwd_mem_out']), - mem_repr(node.meta['bwd_mem_tmp']), - ]) + node_summaries.append( + [ + node.op, + str(node), + time_repr(node.meta["fwd_time"]), + time_repr(node.meta["bwd_time"]), + node.meta["save_fwd_in"], + mem_repr(node.meta["fwd_mem_out"]), + mem_repr(node.meta["fwd_mem_tmp"]), + mem_repr(node.meta["bwd_mem_out"]), + mem_repr(node.meta["bwd_mem_tmp"]), + ] + ) # Use the ``tabulate`` library to create a well-formatted table # presenting our summary information headers: List[str] = [ - 'Op type', - 'Op', - 'Forward time', - 'Backward time', - 'SAVE_FWD_IN', - 'FWD_OUT', - 'FWD_TMP', - 'BWD_OUT', - 'BWD_TMP', + "Op type", + "Op", + "Forward time", + "Backward time", + "SAVE_FWD_IN", + "FWD_OUT", + "FWD_TMP", + "BWD_OUT", + "BWD_TMP", ] - return tabulate(node_summaries, headers=headers, stralign='right') + return tabulate(node_summaries, headers=headers, stralign="right") diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py index 4571bd93a790..3d032a27db63 100644 --- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py +++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py @@ -1,14 +1,11 @@ -import torch -from typing import List -from torch.fx import symbolic_trace -from torch.fx.node import Node -from colossalai.fx.passes.split_module import split_module -from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec import builtins import operator -from copy import deepcopy +from typing import List + +import torch + +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec def apply(*args, **kwargs): @@ -24,16 +21,16 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi origin_node_sharding_spec_dict = {} for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)): strategies_vector = node.strategies_vector - setattr(node, 'best_strategy', strategies_vector[strategy_index]) - setattr(node, 'sharding_spec', strategies_vector[strategy_index].output_sharding_spec) + setattr(node, "best_strategy", strategies_vector[strategy_index]) + setattr(node, "sharding_spec", strategies_vector[strategy_index].output_sharding_spec) origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].output_sharding_spec # apply the sharding spec of parameters for node in nodes: - if node.op == 'call_module': + if node.op == "call_module": target_module = node.graph.owning_module.get_submodule(node.target) origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {}) - setattr(target_module.weight, 'sharding_spec', origin_sharding_spec) + setattr(target_module.weight, "sharding_spec", origin_sharding_spec) target_weight_sharding_spec = node.best_strategy.input_shardings[1] target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3)) apply(target_module.weight, target_weight_sharding_spec) @@ -51,10 +48,10 @@ def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], devi # add above dicts into graph for node in nodes: - if node.op != 'placeholder': + if node.op != "placeholder": with mod_graph.inserting_before(node): - input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict') - origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict') + input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict") + origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict") break return sharding_spec_convert_dict, origin_node_sharding_spec_dict @@ -70,13 +67,13 @@ def shape_consistency_pass(gm: torch.fx.GraphModule): node_to_index_dict = {} index = 0 for node in nodes: - if node.target == 'sharding_spec_convert_dict': + if node.target == "sharding_spec_convert_dict": input_dict_node = node continue - if node.target == 'origin_node_sharding_spec_dict': + if node.target == "origin_node_sharding_spec_dict": origin_dict_node = node continue - if not hasattr(node, 'best_strategy'): + if not hasattr(node, "best_strategy"): continue node_to_index_dict[node] = index index += 1 @@ -84,28 +81,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule): # add shape consistency apply function into graph for node in nodes: - if not hasattr(node, 'best_strategy'): + if not hasattr(node, "best_strategy"): continue with mod_graph.inserting_after(node): - origin_spec_node = mod_graph.create_node('call_function', - operator.getitem, - args=(origin_dict_node, node_to_index_dict[node])) + origin_spec_node = mod_graph.create_node( + "call_function", operator.getitem, args=(origin_dict_node, node_to_index_dict[node]) + ) with mod_graph.inserting_after(origin_spec_node): - set_sharding_spec_node = mod_graph.create_node('call_function', - builtins.setattr, - args=(node, 'sharding_spec', origin_spec_node)) + set_sharding_spec_node = mod_graph.create_node( + "call_function", builtins.setattr, args=(node, "sharding_spec", origin_spec_node) + ) for user_node in node.strategies_vector.successor_nodes: node_index = user_node.strategies_vector.predecessor_nodes.index(node) with mod_graph.inserting_before(user_node): - input_specs_node = mod_graph.create_node('call_function', - operator.getitem, - args=(input_dict_node, node_to_index_dict[node])) + input_specs_node = mod_graph.create_node( + "call_function", operator.getitem, args=(input_dict_node, node_to_index_dict[node]) + ) with mod_graph.inserting_before(user_node): - sharding_spec_node = mod_graph.create_node('call_function', - operator.getitem, - args=(input_specs_node, node_index)) + sharding_spec_node = mod_graph.create_node( + "call_function", operator.getitem, args=(input_specs_node, node_index) + ) with mod_graph.inserting_before(user_node): - shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node)) + shape_consistency_node = mod_graph.create_node("call_function", apply, args=(node, sharding_spec_node)) return gm diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index ab203dfd7440..1720aa58da2b 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -109,13 +109,13 @@ def extract_tensor_meta(obj): return TensorMetadata(None, None, False, None, 0, False) tensor_meta = tree_map(extract_tensor_meta, result) - n.meta['tensor_meta'] = tensor_meta - n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` + n.meta["tensor_meta"] = tensor_meta + n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta` # TODO: the attribute node_size should be removed in the future - setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0))) - setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0)) - setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0)) - n.meta['type'] = type(result) + setattr(n, "node_size", activation_size(n.meta.get("fwd_out", 0)) + activation_size(n.meta.get("fwd_tmp", 0))) + setattr(n, "fwd_flop", n.meta.get("fwd_flop", 0)) + setattr(n, "bwd_flop", n.meta.get("bwd_flop", 0)) + n.meta["type"] = type(result) # retain the autograd graph for param in self.module.parameters(): @@ -125,7 +125,7 @@ def extract_tensor_meta(obj): # Main Node running APIs @compatibility(is_backward_compatible=True) - def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def placeholder(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``placeholder`` node. Note that this is stateful: ``Interpreter`` maintains an internal iterator over @@ -146,7 +146,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return super().placeholder(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def get_attr(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``get_attr`` node. Will retrieve an attribute value from the ``Module`` hierarchy of ``self.module``. @@ -165,7 +165,7 @@ def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[st return super().get_attr(target, args, kwargs), GraphInfo() @compatibility(is_backward_compatible=True) - def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_function`` node with meta tensor and return the result and its meta profile. @@ -184,7 +184,7 @@ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Di return profile_function(target)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_method`` node with meta tensor and return the result and its meta profile. @@ -202,7 +202,7 @@ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return profile_method(target)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute a ``call_module`` node with meta tensor and return the result and its meta profile. @@ -224,7 +224,7 @@ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict return profile_module(submod)(*args, **kwargs) @compatibility(is_backward_compatible=True) - def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: + def output(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: """ Execute an ``output`` node. This really just retrieves the value referenced by the ``output`` node and returns it. @@ -240,7 +240,7 @@ def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, result (Any): The argument value that was retrieved meta_info (MetaInfo): The memory cost and FLOPs estimated with `MetaTensor`. """ - if hasattr(args[0], '_tensor'): + if hasattr(args[0], "_tensor"): return args[0], GraphInfo(fwd_in=[args[0]._tensor]) return args[0], GraphInfo(save_fwd_in=True) @@ -257,7 +257,7 @@ def propagate(self, *args): """ return super().run(*args) - def summary(self, unit: str = 'MB') -> str: + def summary(self, unit: str = "MB") -> str: """ Summarizes the memory and FLOPs statistics of the `GraphModule` in tabular format. Note that this API requires the ``tabulate`` module @@ -267,9 +267,11 @@ def summary(self, unit: str = 'MB') -> str: try: from tabulate import tabulate except ImportError: - print("`summary` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`summary` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) assert self._is_proped, "Please call `interp.run(input)` before calling `interp.summary()`." @@ -278,10 +280,10 @@ def summary(self, unit: str = 'MB') -> str: def mem_repr(mem: int) -> str: unit_divisor_map = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, + "kb": 1024, + "mb": 1024**2, + "gb": 1024**3, + "tb": 1024**4, } return f"{mem / unit_divisor_map[unit.lower()]:.2f} {unit.upper()}" @@ -292,35 +294,37 @@ def flops_repr(flop: int) -> str: for node in self.module.graph.nodes: node: Node accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node) - node_summaries.append([ - node.op, - str(node), - flops_repr(node.meta['fwd_flop']), - flops_repr(node.meta['bwd_flop']), - mem_repr(accumulate_size), - mem_repr(calculate_fwd_in(node)), - mem_repr(calculate_fwd_out(node)), - mem_repr(calculate_fwd_tmp(node)), - mem_repr(node.meta['bwd_mem_out']), - mem_repr(node.meta['bwd_mem_tmp']), - ]) + node_summaries.append( + [ + node.op, + str(node), + flops_repr(node.meta["fwd_flop"]), + flops_repr(node.meta["bwd_flop"]), + mem_repr(accumulate_size), + mem_repr(calculate_fwd_in(node)), + mem_repr(calculate_fwd_out(node)), + mem_repr(calculate_fwd_tmp(node)), + mem_repr(node.meta["bwd_mem_out"]), + mem_repr(node.meta["bwd_mem_tmp"]), + ] + ) # Use the ``tabulate`` library to create a well-formatted table # presenting our summary information headers: List[str] = [ - 'Op type', - 'Op', - 'Forward FLOPs', - 'Backward FLOPs', - 'Accumulated Memory', - 'FWD_IN', - 'FWD_OUT', - 'FWD_TMP', - 'BWD_OUT', - 'BWD_TMP', + "Op type", + "Op", + "Forward FLOPs", + "Backward FLOPs", + "Accumulated Memory", + "FWD_IN", + "FWD_OUT", + "FWD_TMP", + "BWD_OUT", + "BWD_TMP", ] - return tabulate(node_summaries, headers=headers, stralign='right') + return tabulate(node_summaries, headers=headers, stralign="right") def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> None: @@ -344,15 +348,16 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: Returns: torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo. """ - device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") interp = MetaInfoProp(gm.to(device)) if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor + args = tree_map(lambda x: MetaTensor(x, fake_device=device), args) kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs) interp.propagate(*args, **kwargs) if verbose: interp.summary(unit) - gm.to('cpu') + gm.to("cpu") del interp return gm diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py index efdd34a01fe0..73379f73689c 100644 --- a/colossalai/fx/passes/passes_for_gpt2_test.py +++ b/colossalai/fx/passes/passes_for_gpt2_test.py @@ -5,7 +5,6 @@ from packaging import version from torch.fx._compatibility import compatibility from torch.fx.graph_module import GraphModule -from torch.fx.node import Node from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split from colossalai.fx.passes.meta_info_prop import TensorMetadata @@ -13,9 +12,9 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]): - ''' + """ This pass is only used to do the gpt2 performance test, it may move into adding_split_node_pass.py, and will be deprecated in future. - ''' + """ mod_graph = gm.graph valid_children_size = 0 valid_children = [] @@ -39,40 +38,40 @@ def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, parti part_index += 1 pp_size -= 1 with mod_graph.inserting_after(node): - split_node = mod_graph.create_node('call_function', pipe_split) + split_node = mod_graph.create_node("call_function", pipe_split) gm.recompile() return gm def split_with_split_nodes_pass_for_gp2_test(annotated_gm: torch.fx.GraphModule): - ''' + """ This pass will be used in gpt2 test, only a part of changes may be added into split_with_split_nodes_pass, and it will be deprecated in future. - ''' + """ part_idx = 0 def eliminate_unused_placeholders(gm): for node in gm.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": if not len(node.users): gm.graph.erase_node(node) gm.recompile() return gm def refill_outputs_and_placeholders(gm, next_partition_placeholders): - ''' + """ This method is used to eliminate the outputs in previous partition which is unused in next partition. In split module pass, it treats partitions as a DAG, but we need treat them as a single direction linked list in pipeline parallel. The difference is if a output from partition 0 is an input argument of partition 3, the DAG will not transfer it to partition 1 and partition 2. However, in single direction linked list, we need to do so. - ''' + """ output_type = None output_args = [] non_output_list = [] new_placeholder_list = [] for node in gm.graph.nodes: - if node.op == 'output': + if node.op == "output": if isinstance(node.args[0], (tuple, list)): output_type = node.args[0].__class__ output_args.extend([n.name for n in node.args[0]]) @@ -114,7 +113,7 @@ def refill_outputs_and_placeholders(gm, next_partition_placeholders): continue for node in gm.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": new_placeholder_list.append(node.name) if output_type is not None: gm.graph.output(output_type(output_args)) @@ -125,7 +124,7 @@ def refill_outputs_and_placeholders(gm, next_partition_placeholders): def split_callback(n: torch.fx.Node): nonlocal part_idx - if (n.op, n.target) == ('call_function', pipe_split): + if (n.op, n.target) == ("call_function", pipe_split): part_idx += 1 return part_idx @@ -134,7 +133,7 @@ def split_callback(n: torch.fx.Node): for name, submodule in split_mod.named_modules(): if isinstance(submodule, torch.fx.GraphModule): for node in submodule.graph.nodes: - if (node.op, node.target) == ('call_function', pipe_split): + if (node.op, node.target) == ("call_function", pipe_split): submodule.graph.erase_node(node) submodule.recompile() split_submodules.append(submodule) @@ -200,13 +199,12 @@ def _gen_all_ancestors_set(node): _gen_all_ancestors_set(node) for n in list(all_ancestors): - if n.op != 'placeholder' and n._fx_partition > partition_name: + if n.op != "placeholder" and n._fx_partition > partition_name: n._fx_partition = partition_name - def record_cross_partition_use(def_node: torch.fx.node.Node, - use_node: Optional[torch.fx.node.Node]): # noqa: B950 - def_partition_name = getattr(def_node, '_fx_partition', None) - use_partition_name = getattr(use_node, '_fx_partition', None) + def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, "_fx_partition", None) + use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: # if 'tensor_meta' in def_node.meta: # if not _node_with_all_tensor_element(def_node.meta['tensor_meta']): @@ -237,7 +235,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, if node.op in ["placeholder"]: continue - if node.op == 'output': + if node.op == "output": # partition_name = str(split_callback(node)) # def _set_output_args_partition(n, partition_name): # n._fx_partition = partition_name @@ -252,12 +250,12 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, partitions[partition_name] = partition = Partition(partition_name) partition.node_names.append(node.name) - origin_partition_name = getattr(node, '_fx_partition', None) + origin_partition_name = getattr(node, "_fx_partition", None) if origin_partition_name is None: node._fx_partition = partition_name torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) - torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 # find partitions with no dependencies root_partitions: List[str] = [] @@ -287,7 +285,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, # Transform nodes and collect targets for partition's submodule for node in m.graph.nodes: - if hasattr(node, '_fx_partition'): + if hasattr(node, "_fx_partition"): partition = partitions[node._fx_partition] # swap out old graph nodes in kw/args with references to new nodes in this submodule @@ -295,26 +293,24 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) - if node.op not in ['call_module', 'get_attr']: + if node.op not in ["call_module", "get_attr"]: target = node.target else: - target_atoms = node.target.split('.') + target_atoms = node.target.split(".") target_attr = m for atom in target_atoms: if not hasattr(target_attr, atom): - raise RuntimeError(f'Operator target {node.target} not found!') + raise RuntimeError(f"Operator target {node.target} not found!") target_attr = getattr(target_attr, atom) # target = target_atoms[-1] - target = '_'.join(target_atoms) + target = "_".join(target_atoms) partition.targets[target] = target_attr assert isinstance(gathered_args, tuple) assert isinstance(gathered_kwargs, dict) - new_node = partition.graph.create_node(op=node.op, - target=target, - args=gathered_args, - kwargs=gathered_kwargs, - name=node.name) + new_node = partition.graph.create_node( + op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs, name=node.name + ) new_node.meta = node.meta.copy() partition.environment[node] = new_node @@ -323,14 +319,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} for node in m.graph.nodes: - if node.op == 'placeholder': - if version.parse(torch.__version__) < version.parse('1.11.0'): + if node.op == "placeholder": + if version.parse(torch.__version__) < version.parse("1.11.0"): base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type) else: default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty - base_mod_env[node.name] = base_mod_graph.placeholder(node.name, - type_expr=node.type, - default_value=default_value) + base_mod_env[node.name] = base_mod_graph.placeholder( + node.name, type_expr=node.type, default_value=default_value + ) base_mod_env[node.name].meta = node.meta.copy() # Do some things iterating over the partitions in topological order again: @@ -344,13 +340,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, # Set correct output values output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) - output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] partition.graph.output(output_vals) # Construct GraphModule for this partition - submod_name = f'submod_{partition_name}' - base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, - partition.graph) # noqa: B950 + submod_name = f"submod_{partition_name}" + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule( + partition.targets, partition.graph + ) # noqa: B950 # Emit call in base graph to this submodule output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) @@ -358,14 +355,14 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, # Unpack multiple return values from submodule output_val_proxy = torch.fx.proxy.Proxy(output_val) for i, output_name in enumerate(partition.outputs): - base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] else: if not partition.outputs: continue base_mod_env[list(partition.outputs)[0]] = output_val for node in m.graph.nodes: - if node.op == 'output': - base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 + if node.op == "output": + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py index ccbab0c38a29..be8261f2a3f4 100644 --- a/colossalai/fx/passes/shard_1d_pass.py +++ b/colossalai/fx/passes/shard_1d_pass.py @@ -9,8 +9,19 @@ ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] ELEMENTWISE_FUNC_OP = [ - torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, - operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout + torch.add, + operator.add, + torch.abs, + torch.cos, + torch.exp, + torch.mul, + operator.mul, + operator.floordiv, + operator.truediv, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, ] @@ -72,7 +83,7 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size): # traverse the graph to look for consecutive linear layers is_linear_module = False - if node.op == 'call_module': + if node.op == "call_module": # look for the linear layer module = node.graph.owning_module.get_submodule(node.target) if isinstance(module, nn.Linear): @@ -82,31 +93,31 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size): # it means the first linear has been found and the current module # is the second linear # set the current linear module to be row-sharded - annotation_record['row'] = module + annotation_record["row"] = module for shard_type, module in annotation_record.items(): # add row sharding spec - if shard_type == 'row': + if shard_type == "row": dist_spec = ShardSpec(dims=[-1], num_partitions=[world_size]) comp_spec = ComputeSpec(ComputePattern.TP1D) - setattr(module.weight, 'pg', process_group) - setattr(module.weight, 'dist_spec', dist_spec) - setattr(module.weight, 'comp_spec', comp_spec) - elif shard_type == 'col': + setattr(module.weight, "pg", process_group) + setattr(module.weight, "dist_spec", dist_spec) + setattr(module.weight, "comp_spec", comp_spec) + elif shard_type == "col": weight_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) weight_comp_spec = ComputeSpec(ComputePattern.TP1D) weight_comp_spec.output_replicate = False - setattr(module.weight, 'pg', process_group) - setattr(module.weight, 'dist_spec', weight_dist_spec) - setattr(module.weight, 'comp_spec', weight_comp_spec) + setattr(module.weight, "pg", process_group) + setattr(module.weight, "dist_spec", weight_dist_spec) + setattr(module.weight, "comp_spec", weight_comp_spec) if module.bias is not None: bias_dist_spec = ShardSpec(dims=[0], num_partitions=[world_size]) bias_comp_spec = ComputeSpec(ComputePattern.TP1D) bias_comp_spec.output_replicate = False - setattr(module.bias, 'pg', process_group) - setattr(module.bias, 'dist_spec', bias_dist_spec) - setattr(module.bias, 'comp_spec', bias_comp_spec) + setattr(module.bias, "pg", process_group) + setattr(module.bias, "dist_spec", bias_dist_spec) + setattr(module.bias, "comp_spec", bias_comp_spec) start_tracking = False annotation_record.clear() else: @@ -114,16 +125,16 @@ def _traverse_and_annotate(node, start_tracking, annotation_record, world_size): # it means the current layer is the first linear # set the linear layer to be col-sharded start_tracking = True - annotation_record['col'] = module + annotation_record["col"] = module if start_tracking and not is_linear_module: # check against the white list # if non-element wise op is found, we reset the tracking - if node.op == 'call_module': + if node.op == "call_module": module = node.graph.owning_module.get_submodule(node.target) if module.__class__ not in ELEMENTWISE_MODULE_OP: start_tracking = False - elif node.op == 'call_function' or node.op == 'call_method': + elif node.op == "call_function" or node.op == "call_method": if node.target not in ELEMENTWISE_FUNC_OP: start_tracking = False elif len(node.users.keys()) > 1: diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py index 61ed037ab7a1..67a2432595d6 100644 --- a/colossalai/fx/passes/split_module.py +++ b/colossalai/fx/passes/split_module.py @@ -25,12 +25,14 @@ def __init__(self, name: str): self.targets: Dict[str, Any] = {} def __repr__(self) -> str: - return f"name: {self.name},\n" \ - f" nodes: {self.node_names},\n" \ - f" inputs: {self.inputs},\n" \ - f" outputs: {self.outputs},\n" \ - f" partitions dependent on: {self.partitions_dependent_on},\n" \ + return ( + f"name: {self.name},\n" + f" nodes: {self.node_names},\n" + f" inputs: {self.inputs},\n" + f" outputs: {self.outputs},\n" + f" partitions dependent on: {self.partitions_dependent_on},\n" f" partition dependents: {self.partition_dependents}" + ) # Creates subgraphs out of main graph @@ -117,10 +119,9 @@ def forward(self, x, y): partitions: Dict[str, Partition] = {} orig_nodes: Dict[str, torch.fx.node.Node] = {} - def record_cross_partition_use(def_node: torch.fx.node.Node, - use_node: Optional[torch.fx.node.Node]): # noqa: B950 - def_partition_name = getattr(def_node, '_fx_partition', None) - use_partition_name = getattr(use_node, '_fx_partition', None) + def record_cross_partition_use(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def_partition_name = getattr(def_node, "_fx_partition", None) + use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: if def_partition_name is not None: def_partition = partitions[def_partition_name] @@ -134,7 +135,7 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, if def_partition_name is not None: use_partition.partitions_dependent_on.setdefault(def_partition_name) - def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 + def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 def_partition_name = getattr(def_node, "_fx_partition", None) use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: @@ -161,7 +162,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node if node.op in ["placeholder"]: continue - if node.op == 'output': + if node.op == "output": if merge_output: torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev)) else: @@ -178,7 +179,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node node._fx_partition = partition_name torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) - torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 + torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950 # find partitions with no dependencies root_partitions: List[str] = [] @@ -208,7 +209,7 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node # Transform nodes and collect targets for partition's submodule for node in m.graph.nodes: - if hasattr(node, '_fx_partition'): + if hasattr(node, "_fx_partition"): partition = partitions[node._fx_partition] # swap out old graph nodes in kw/args with references to new nodes in this submodule @@ -216,25 +217,24 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n: environment[n]) - if node.op not in ['call_module', 'get_attr']: + if node.op not in ["call_module", "get_attr"]: target = node.target else: - target_atoms = node.target.split('.') + target_atoms = node.target.split(".") target_attr = m for atom in target_atoms: if not hasattr(target_attr, atom): - raise RuntimeError(f'Operator target {node.target} not found!') + raise RuntimeError(f"Operator target {node.target} not found!") target_attr = getattr(target_attr, atom) # target = target_atoms[-1] - target = '_'.join(target_atoms) + target = "_".join(target_atoms) partition.targets[target] = target_attr assert isinstance(gathered_args, tuple) assert isinstance(gathered_kwargs, dict) - new_node = partition.graph.create_node(op=node.op, - target=target, - args=gathered_args, - kwargs=gathered_kwargs) + new_node = partition.graph.create_node( + op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs + ) new_node.meta = node.meta.copy() partition.environment[node] = new_node @@ -243,14 +243,14 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} for node in m.graph.nodes: - if node.op == 'placeholder': - if version.parse(torch.__version__) < version.parse('1.11.0'): + if node.op == "placeholder": + if version.parse(torch.__version__) < version.parse("1.11.0"): base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type) else: default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty - base_mod_env[node.name] = base_mod_graph.placeholder(node.target, - type_expr=node.type, - default_value=default_value) + base_mod_env[node.name] = base_mod_graph.placeholder( + node.target, type_expr=node.type, default_value=default_value + ) base_mod_env[node.name].meta = node.meta.copy() # Do some things iterating over the partitions in topological order again: @@ -264,13 +264,14 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node # Set correct output values output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs) - output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] + output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore[assignment] partition.graph.output(output_vals) # Construct GraphModule for this partition - submod_name = f'submod_{partition_name}' - base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, - partition.graph) # noqa: B950 + submod_name = f"submod_{partition_name}" + base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule( + partition.targets, partition.graph + ) # noqa: B950 # Emit call in base graph to this submodule output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs)) @@ -278,15 +279,15 @@ def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node # Unpack multiple return values from submodule output_val_proxy = torch.fx.proxy.Proxy(output_val) for i, output_name in enumerate(partition.outputs): - base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] else: if not partition.outputs: continue base_mod_env[list(partition.outputs)[0]] = output_val for node in m.graph.nodes: - if node.op == 'output': - base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 + if node.op == "output": + base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])) # noqa: B950 for partition_name in sorted_partitions: partition = partitions[partition_name] diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py index bb4f3cd6a490..c51f49a30e8a 100644 --- a/colossalai/fx/passes/utils.py +++ b/colossalai/fx/passes/utils.py @@ -1,7 +1,9 @@ -import torch from typing import Dict -from torch.fx.node import Node, map_arg + +import torch from torch.fx.graph import Graph +from torch.fx.node import Node, map_arg + def get_comm_size(prev_partition, next_partition): """ @@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition): map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) for n in input_nodes: if n.name in parent_node_names and n not in visited_nodes: - comm_size += n.meta['tensor_meta'].numel + comm_size += n.meta["tensor_meta"].numel visited_nodes.add(n) return comm_size @@ -36,12 +38,12 @@ def get_leaf(graph: Graph): """ input_nodes: Dict[Node, None] = {} for node in graph.nodes: - if node.op == 'output': + if node.op == "output": map_arg(node.args, lambda n: input_nodes.setdefault(n)) map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) placeholder_nodes = [] for node in input_nodes.keys(): - if node.op == 'placeholder': + if node.op == "placeholder": placeholder_nodes.append(node) for node in placeholder_nodes: input_nodes.pop(node) @@ -60,13 +62,13 @@ def get_top(graph: Graph): """ top_node_list = set() for node in graph.nodes: - if node.op == 'output': + if node.op == "output": continue is_top = False def _get_top(node): nonlocal is_top - if node.op == 'placeholder': + if node.op == "placeholder": is_top = True map_arg(node.args, lambda n: _get_top(n)) @@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node): def get_all_consumers(graph: Graph, node: Node): """ Given a graph and a node of this graph, return all consumers of the node. - + Returns: List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``. """ @@ -120,7 +122,7 @@ def forward(self, x): for node in gm.graph.nodes: if hasattr(node, 'bfs_level'): print(node.name, node.bfs_level) - + Output: graph(): %x : [#users=2] = placeholder[target=x] @@ -148,7 +150,7 @@ def forward(self, x): while nodes_to_process: new_process_list = [] for node in nodes_to_process: - if node.op == 'output': + if node.op == "output": continue node.bfs_level = current_level new_process_list.extend(get_all_consumers(graph, node)) @@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module: torch.nn.Module: the module associated with the given node """ - assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object' - assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}' + assert ( + node.graph.owning_module is not None + ), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object" + assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}" module = node.graph.owning_module.get_submodule(node.target) return module - diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py index 8bcbde0eb23b..89dd2b3df617 100644 --- a/colossalai/fx/profiler/__init__.py +++ b/colossalai/fx/profiler/__init__.py @@ -12,7 +12,16 @@ ) from .tensor import MetaTensor else: - from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out + from .experimental import ( + meta_profiler_function, + meta_profiler_module, + profile_function, + profile_method, + profile_module, + calculate_fwd_in, + calculate_fwd_tmp, + calculate_fwd_out, + ) from .dataflow import GraphInfo from .memory_utils import activation_size, is_inplace, parameter_size diff --git a/colossalai/fx/profiler/constants.py b/colossalai/fx/profiler/constants.py index 5763a46dc83f..fad9bb272bff 100644 --- a/colossalai/fx/profiler/constants.py +++ b/colossalai/fx/profiler/constants.py @@ -1,6 +1,6 @@ import torch -__all__ = ['ALIAS_ATEN', 'INPLACE_NEW', 'INPLACE_MATH_ATEN', 'CLONE_ATEN', 'RELU_LIKE_OPS', 'RELU_LIKE_MOD'] +__all__ = ["ALIAS_ATEN", "INPLACE_NEW", "INPLACE_MATH_ATEN", "CLONE_ATEN", "RELU_LIKE_OPS", "RELU_LIKE_MOD"] aten = torch.ops.aten diff --git a/colossalai/fx/profiler/dataflow.py b/colossalai/fx/profiler/dataflow.py index a5e8880322b8..05f9b50ce575 100644 --- a/colossalai/fx/profiler/dataflow.py +++ b/colossalai/fx/profiler/dataflow.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field from enum import Enum -from functools import partial from typing import Dict, List from torch.fx import Graph, Node @@ -69,8 +68,8 @@ class GraphInfo: def is_phase(n: Node, phase: Phase) -> bool: - assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!' - return n.meta['phase'] == phase + assert "phase" in n.meta, f"Node meta of {n} has no key `phase`!" + return n.meta["phase"] == phase @compatibility(is_backward_compatible=False) @@ -103,9 +102,9 @@ def _peak_memory(deps: Dict[Node, int]): peak_mem = 0 for k, v in deps.items(): if v > 0 and is_phase(k, Phase.BACKWARD) and not all(map(is_inplace, k.users)) and not is_inplace(k): - peak_mem += activation_size(k.meta['saved_tensor']) - if v <= float('-inf') and is_phase(k, Phase.FORWARD): - peak_mem -= activation_size(k.meta['saved_tensor']) + peak_mem += activation_size(k.meta["saved_tensor"]) + if v <= float("-inf") and is_phase(k, Phase.FORWARD): + peak_mem -= activation_size(k.meta["saved_tensor"]) return peak_mem # deps is used to track all the memory dependencies of the graph. @@ -123,19 +122,19 @@ def _peak_memory(deps: Dict[Node, int]): # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint # the node, `fwd_mem_tmp` can be freed. if is_phase(n, Phase.PLACEHOLDER): - graph_info.fwd_in += n.meta['saved_tensor'] + graph_info.fwd_in += n.meta["saved_tensor"] if is_phase(n, Phase.FORWARD): - graph_info.fwd_tmp += n.meta['saved_tensor'] + graph_info.fwd_tmp += n.meta["saved_tensor"] elif is_phase(n, Phase.BACKWARD): if len(n.users): graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps)) else: # TODO: some of the bwd_mem_out might be model parameters. # basically a backward node without user is a `grad_out` node - graph_info.bwd_mem_out += activation_size(n.meta['saved_tensor']) + graph_info.bwd_mem_out += activation_size(n.meta["saved_tensor"]) for input_n in n.all_input_nodes: if input_n in deps: deps[input_n] -= 1 if deps[input_n] <= 0: - deps[input_n] = float('-inf') + deps[input_n] = float("-inf") return graph_info diff --git a/colossalai/fx/profiler/experimental/constants.py b/colossalai/fx/profiler/experimental/constants.py index 57ff3fd91299..02758e7643af 100644 --- a/colossalai/fx/profiler/experimental/constants.py +++ b/colossalai/fx/profiler/experimental/constants.py @@ -2,7 +2,7 @@ import torch -__all__ = ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD'] +__all__ = ["INPLACE_OPS", "INPLACE_METHOD", "NON_INPLACE_METHOD"] # TODO fill out the inplace ops INPLACE_OPS = [ @@ -20,25 +20,25 @@ # TODO: list all call_methods that are inplace here INPLACE_METHOD = [ - 'transpose', - 'permute', + "transpose", + "permute", # TODO: reshape may return a copy of the data if the data is not contiguous - 'reshape', - 'dim', - 'flatten', - 'size', - 'view', - 'unsqueeze', - 'to', - 'type', - 'flatten', + "reshape", + "dim", + "flatten", + "size", + "view", + "unsqueeze", + "to", + "type", + "flatten", ] # TODO: list all call_methods that are not inplace here NON_INPLACE_METHOD = [ - 'chunk', - 'contiguous', - 'expand', - 'mean', - 'split', + "chunk", + "contiguous", + "expand", + "mean", + "split", ] diff --git a/colossalai/fx/profiler/experimental/profiler.py b/colossalai/fx/profiler/experimental/profiler.py index 5c545260e72b..d890fdb66fc2 100644 --- a/colossalai/fx/profiler/experimental/profiler.py +++ b/colossalai/fx/profiler/experimental/profiler.py @@ -9,7 +9,7 @@ from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD from .registry import meta_profiler_function, meta_profiler_module -__all__ = ['profile_function', 'profile_module', 'profile_method'] +__all__ = ["profile_function", "profile_module", "profile_method"] # this is for compatibility use @@ -42,6 +42,7 @@ class GraphInfo: bwd_mem_tmp (int): See the above illustration. bwd_mem_out (int): See the above illustration. """ + fwd_flop: int = 0 bwd_flop: int = 0 fwd_mem_in: int = 0 @@ -50,8 +51,7 @@ class GraphInfo: bwd_mem_out: int = 0 -CALL_FUNCTION_MSG = \ -""" +CALL_FUNCTION_MSG = """ Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n from colossalai.fx.profiler.experimental import meta_profiler_function @meta_profiler_function.register(YOUR_FUNCTION) @@ -60,9 +60,8 @@ def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]: macs = ... return flops, macs """ -CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}' -CALL_MODULE_MSG = \ -""" +CALL_METHOD_MSG = "Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}" +CALL_MODULE_MSG = """ Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n from colossalai.fx.profiler.experimental import meta_profiler_module @meta_profiler_module.register(YOUR_MODULE) @@ -74,7 +73,7 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int @compatibility(is_backward_compatible=True) -def profile_function(target: 'Target') -> Callable: +def profile_function(target: "Target") -> Callable: """ Wrap a `call_function` node or `torch.nn.functional` in order to record the memory cost and FLOPs of the execution. @@ -92,12 +91,13 @@ def profile_function(target: 'Target') -> Callable: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: assert meta_profiler_function.has(target) or meta_profiler_function.has( - target.__name__), CALL_FUNCTION_MSG.format(target) + target.__name__ + ), CALL_FUNCTION_MSG.format(target) fwd_tmp = 0 fwd_out = 0 out = func(*args, **kwargs) - if target not in INPLACE_OPS and not kwargs.get('inplace', False): + if target not in INPLACE_OPS and not kwargs.get("inplace", False): fwd_out = activation_size(out) if meta_profiler_function.has(target): profiler = meta_profiler_function.get(target) @@ -112,7 +112,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: @compatibility(is_backward_compatible=True) -def profile_method(target: 'Target') -> Callable: +def profile_method(target: "Target") -> Callable: """ Wrap a `call_method` node record the memory cost and FLOPs of the execution. @@ -126,11 +126,12 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: self_obj, *args_tail = args # execute the method and return the result - assert isinstance(target, str), f'{target} instance is not str.' + assert isinstance(target, str), f"{target} instance is not str." out = getattr(self_obj, target)(*args_tail, **kwargs) assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format( - target, INPLACE_METHOD, NON_INPLACE_METHOD) + target, INPLACE_METHOD, NON_INPLACE_METHOD + ) # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs. fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out) fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out) @@ -161,7 +162,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: fwd_tmp = 0 fwd_out = 0 out = func(*args, **kwargs) - if getattr(module, 'inplace', False): + if getattr(module, "inplace", False): fwd_out = activation_size(out) profiler = meta_profiler_module.get(type(module)) fwd_flop, _ = profiler(module, *args, **kwargs) diff --git a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py index a43aef063e19..c518ec28da41 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py +++ b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_function # TODO: different activation has different FLOPs count, currently unused. diff --git a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py index 8d1c8a8c6877..f1b9bb97c6c6 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py +++ b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py @@ -41,15 +41,15 @@ def _elementwise_flops_compute(input, other): @meta_profiler_function.register(torch.sub) @meta_profiler_function.register(torch.mul) @meta_profiler_function.register(torch.floor_divide) -@meta_profiler_function.register('add') # for built-in op + -@meta_profiler_function.register('iadd') # for built-in op += -@meta_profiler_function.register('eq') # for built-in op = -@meta_profiler_function.register('sub') # for built-in op - -@meta_profiler_function.register('isub') # for built-in op -= -@meta_profiler_function.register('mul') # for built-in op * -@meta_profiler_function.register('imul') # for built-in op *= -@meta_profiler_function.register('floordiv') # for built-in op // -@meta_profiler_function.register('ifloordiv') # for built-in op //= +@meta_profiler_function.register("add") # for built-in op + +@meta_profiler_function.register("iadd") # for built-in op += +@meta_profiler_function.register("eq") # for built-in op = +@meta_profiler_function.register("sub") # for built-in op - +@meta_profiler_function.register("isub") # for built-in op -= +@meta_profiler_function.register("mul") # for built-in op * +@meta_profiler_function.register("imul") # for built-in op *= +@meta_profiler_function.register("floordiv") # for built-in op // +@meta_profiler_function.register("ifloordiv") # for built-in op //= def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: return _elementwise_flops_compute(input, other) @@ -62,7 +62,7 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N @meta_profiler_function.register(torch.matmul) -@meta_profiler_function.register('matmul') # for built-in op @ +@meta_profiler_function.register("matmul") # for built-in op @ @meta_profiler_function.register(torch.Tensor.matmul) def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: macs = reduce(operator.mul, input.shape) * other.shape[-1] @@ -78,13 +78,15 @@ def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.T @meta_profiler_function.register(torch.var_mean) -def torch_var_mean(input: torch.Tensor, - dim: Union[int, Tuple[int, ...]], - unbiased: Optional[bool] = True, - keepdim: Optional[bool] = False, - *, - out: Optional[torch.Tensor] = None) -> Tuple[int, int]: - assert out is None, 'saving to out is not supported yet' +def torch_var_mean( + input: torch.Tensor, + dim: Union[int, Tuple[int, ...]], + unbiased: Optional[bool] = True, + keepdim: Optional[bool] = False, + *, + out: Optional[torch.Tensor] = None, +) -> Tuple[int, int]: + assert out is None, "saving to out is not supported yet" flops = input.numel() * 3 macs = 0 return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_function/embedding.py b/colossalai/fx/profiler/experimental/profiler_function/embedding.py index d6e43d781b8b..1d362015fc8b 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/embedding.py +++ b/colossalai/fx/profiler/experimental/profiler_function/embedding.py @@ -1,5 +1,7 @@ -import torch from typing import Optional + +import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/linear.py b/colossalai/fx/profiler/experimental/profiler_function/linear.py index 01fe4c871370..ecc578d61b91 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/linear.py +++ b/colossalai/fx/profiler/experimental/profiler_function/linear.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/normalization.py b/colossalai/fx/profiler/experimental/profiler_function/normalization.py index c4ea508d70f8..2ad029eda039 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/normalization.py +++ b/colossalai/fx/profiler/experimental/profiler_function/normalization.py @@ -1,5 +1,7 @@ from typing import List, Optional, Tuple + import torch + from ..registry import meta_profiler_function @@ -21,11 +23,13 @@ def torch_nn_func_instancenorm( @meta_profiler_function.register(torch.nn.functional.group_norm) -def torch_nn_func_groupnorm(input: torch.Tensor, - num_groups: int, - weight: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - eps: float = 1e-5) -> Tuple[int, int]: +def torch_nn_func_groupnorm( + input: torch.Tensor, + num_groups: int, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-5, +) -> Tuple[int, int]: has_affine = weight is not None flops = input.numel() * (5 if has_affine else 4) macs = 0 diff --git a/colossalai/fx/profiler/experimental/profiler_function/pooling.py b/colossalai/fx/profiler/experimental/profiler_function/pooling.py index a639f5ee83c1..c91deab906d4 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/pooling.py +++ b/colossalai/fx/profiler/experimental/profiler_function/pooling.py @@ -1,5 +1,7 @@ -from typing import Tuple, Union +from typing import Tuple + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py index 1e8561206ba0..58c9889ad98e 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py +++ b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py @@ -1,6 +1,6 @@ import operator from typing import Any, Tuple -import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py index abdd7ad565ba..67e90fb69acd 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py +++ b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py @@ -1,7 +1,9 @@ -from functools import reduce import operator +from functools import reduce from typing import Any, Optional, Tuple + import torch + from ..registry import meta_profiler_function @@ -43,13 +45,11 @@ def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]: @meta_profiler_function.register(torch.max) -def torch_max(input: torch.Tensor, - dim: int = None, - keepdim: bool = False, - *, - out: Optional[torch.Tensor] = None) -> Tuple[int, int]: +def torch_max( + input: torch.Tensor, dim: int = None, keepdim: bool = False, *, out: Optional[torch.Tensor] = None +) -> Tuple[int, int]: macs = 0 - assert out is None, 'assigning value to out is not supported yet' + assert out is None, "assigning value to out is not supported yet" if dim is not None: shape = list(input.shape) shape.pop(int(dim)) diff --git a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py index 2ebf514ad269..ae065e0c7c17 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py +++ b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module # TODO: different activation has different FLOPs count, currently unused. diff --git a/colossalai/fx/profiler/experimental/profiler_module/attention.py b/colossalai/fx/profiler/experimental/profiler_module/attention.py index 8daf74b232bf..dfaee75e0432 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/attention.py +++ b/colossalai/fx/profiler/experimental/profiler_module/attention.py @@ -1,19 +1,23 @@ from typing import Optional, Tuple + import torch + from ..registry import meta_profiler_module # TODO: This is hard to compute memory cost @meta_profiler_module.register(torch.nn.MultiheadAttention) -def torch_nn_msa(self: torch.nn.MultiheadAttention, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_padding_mask: Optional[torch.Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[torch.Tensor] = None, - average_attn_weights: bool = True) -> Tuple[int, int]: - if getattr(self, 'batch_first', False): +def torch_nn_msa( + self: torch.nn.MultiheadAttention, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[torch.Tensor] = None, + average_attn_weights: bool = True, +) -> Tuple[int, int]: + if getattr(self, "batch_first", False): batch_size = query.shape[0] len_idx = 1 else: @@ -44,15 +48,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention, flops += qlen * qdim # Initial projections - flops += 2 * ((qlen * qdim * qdim) # QW - + (klen * kdim * kdim) # KW - + (vlen * vdim * vdim) # VW - ) + flops += 2 * ((qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim)) # QW # KW # VW - macs += ((qlen * qdim * qdim) # QW - + (klen * kdim * kdim) # KW - + (vlen * vdim * vdim) # VW - ) + macs += (qlen * qdim * qdim) + (klen * kdim * kdim) + (vlen * vdim * vdim) # QW # KW # VW if self.in_proj_bias is not None: flops += (qlen + klen + vlen) * qdim @@ -62,13 +60,9 @@ def torch_nn_msa(self: torch.nn.MultiheadAttention, v_head_dim = vdim // num_heads head_flops = ( - 2 * (qlen * klen * qk_head_dim) # QK^T - + (qlen * klen) # softmax - + 2 * (qlen * klen * v_head_dim) # AV + 2 * (qlen * klen * qk_head_dim) + (qlen * klen) + 2 * (qlen * klen * v_head_dim) # QK^T # softmax # AV ) - head_macs = ((qlen * klen * qk_head_dim) # QK^T - + 2 * (qlen * klen * v_head_dim) # AV - ) + head_macs = (qlen * klen * qk_head_dim) + 2 * (qlen * klen * v_head_dim) # QK^T # AV flops += num_heads * head_flops macs += num_heads * head_flops diff --git a/colossalai/fx/profiler/experimental/profiler_module/convolution.py b/colossalai/fx/profiler/experimental/profiler_module/convolution.py index a4c15b91e611..90e494c77f5b 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/convolution.py +++ b/colossalai/fx/profiler/experimental/profiler_module/convolution.py @@ -17,8 +17,9 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, in # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html c_in, l_in = input.shape[-2:] c_out = self.out_channels - l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + l_out = math.floor( + (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, @@ -38,10 +39,12 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, in # at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html c_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + h_out = math.floor( + (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, @@ -62,12 +65,15 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, in # at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html c_in, d_in, h_in, w_in = input.shape[-4:] c_out = self.out_channels - d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) - w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * - (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + d_out = math.floor( + (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + h_out = math.floor( + (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, @@ -89,8 +95,13 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html c_in, l_in = input.shape[-2:] c_out = self.out_channels - l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + l_out = math.floor( + (l_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, @@ -98,7 +109,7 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups num_elem = reduce( operator.mul, input.shape - ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604 + ) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604 macs = macs_per_elem * num_elem flops = 2 * macs if self.bias is not None: @@ -112,10 +123,20 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html c_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + h_out = math.floor( + (h_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, @@ -136,12 +157,27 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html c_in, d_in, h_in, w_in = input.shape[-4:] c_out = self.out_channels - d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) - w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * - (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + d_out = math.floor( + (d_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + h_out = math.floor( + (h_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[2] + - 2 * self.padding[2] + + self.dilation[2] * (self.kernel_size[2] - 1) + + self.output_padding[2] + + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, diff --git a/colossalai/fx/profiler/experimental/profiler_module/dropout.py b/colossalai/fx/profiler/experimental/profiler_module/dropout.py index 417e0ed46863..7361239eb1bd 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/dropout.py +++ b/colossalai/fx/profiler/experimental/profiler_module/dropout.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/linear.py b/colossalai/fx/profiler/experimental/profiler_module/linear.py index e1ffb6f244d2..71fed3196c13 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/linear.py +++ b/colossalai/fx/profiler/experimental/profiler_module/linear.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/normalization.py b/colossalai/fx/profiler/experimental/profiler_module/normalization.py index 49e5e6fa5384..5a64e44947b7 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/normalization.py +++ b/colossalai/fx/profiler/experimental/profiler_module/normalization.py @@ -16,8 +16,12 @@ @meta_profiler_module.register(torch.nn.BatchNorm1d) @meta_profiler_module.register(torch.nn.BatchNorm2d) @meta_profiler_module.register(torch.nn.BatchNorm3d) -def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, - torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]: +def torch_nn_normalize( + self: Union[ + torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d + ], + input: torch.Tensor, +) -> Tuple[int, int]: # adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615 has_affine = self.weight is not None if self.training: @@ -30,6 +34,7 @@ def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch try: import apex + meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) diff --git a/colossalai/fx/profiler/experimental/profiler_module/pooling.py b/colossalai/fx/profiler/experimental/profiler_module/pooling.py index e429ac3eea28..b3b630b2dee9 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/pooling.py +++ b/colossalai/fx/profiler/experimental/profiler_module/pooling.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/rnn.py b/colossalai/fx/profiler/experimental/profiler_module/rnn.py index 6e733d6da915..8a4c828dbd27 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/rnn.py +++ b/colossalai/fx/profiler/experimental/profiler_module/rnn.py @@ -1,12 +1,15 @@ -from functools import reduce import operator +from functools import reduce +from typing import Optional, Tuple + import torch + from ..registry import meta_profiler_module -from typing import Optional, Tuple, Union -def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, - w_hh: torch.Tensor) -> Tuple[int, int]: +def _rnn_flops( + flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, w_hh: torch.Tensor +) -> Tuple[int, int]: # copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py # matrix matrix mult ih state and internal state @@ -42,12 +45,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch flops = 0 macs = 0 for i in range(self.num_layers): - w_ih = self.__getattr__('weight_ih_l' + str(i)) - w_hh = self.__getattr__('weight_hh_l' + str(i)) + w_ih = self.__getattr__("weight_ih_l" + str(i)) + w_hh = self.__getattr__("weight_hh_l" + str(i)) flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) if self.bias: - b_ih = self.__getattr__('bias_ih_l' + str(i)) - b_hh = self.__getattr__('bias_hh_l' + str(i)) + b_ih = self.__getattr__("bias_ih_l" + str(i)) + b_hh = self.__getattr__("bias_hh_l" + str(i)) flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) flops *= reduce(operator.mul, input.shape[:2]) macs *= reduce(operator.mul, input.shape[:2]) @@ -63,12 +66,12 @@ def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch def torch_nn_rnn(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]: flops = 0 macs = 0 - w_ih = self.__getattr__('weight_ih_l') - w_hh = self.__getattr__('weight_hh_l') + w_ih = self.__getattr__("weight_ih_l") + w_hh = self.__getattr__("weight_hh_l") flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh) if self.bias: - b_ih = self.__getattr__('bias_ih_l') - b_hh = self.__getattr__('bias_hh_l') + b_ih = self.__getattr__("bias_ih_l") + b_hh = self.__getattr__("bias_hh_l") flops += reduce(operator.mul, b_ih) + reduce(operator.mul, b_hh) flops *= input.shape[0] macs *= input.shape[0] diff --git a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py index d3aed874eb10..06be25246a71 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py +++ b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py @@ -1,7 +1,8 @@ -import operator +from typing import Tuple + import torch + from ..registry import meta_profiler_module -from typing import Optional, Tuple, Union @meta_profiler_module.register(torch.nn.Flatten) diff --git a/colossalai/fx/profiler/experimental/registry.py b/colossalai/fx/profiler/experimental/registry.py index 7d73bce321e4..d47129cd2978 100644 --- a/colossalai/fx/profiler/experimental/registry.py +++ b/colossalai/fx/profiler/experimental/registry.py @@ -1,11 +1,9 @@ class ProfilerRegistry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): self.store[source] = func return func @@ -21,5 +19,5 @@ def has(self, source): return source in self.store -meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile') -meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile') +meta_profiler_function = ProfilerRegistry(name="patched_functions_for_meta_profile") +meta_profiler_module = ProfilerRegistry(name="patched_modules_for_meta_profile") diff --git a/colossalai/fx/profiler/experimental/shard_utils.py b/colossalai/fx/profiler/experimental/shard_utils.py index 1e53ed0bf8ec..90e8c3b7cfe4 100644 --- a/colossalai/fx/profiler/experimental/shard_utils.py +++ b/colossalai/fx/profiler/experimental/shard_utils.py @@ -1,8 +1,6 @@ # for PyTorch 1.11 compatibility uses -from typing import Dict, List, Tuple, Union -import torch -from torch.fx import GraphModule, Node +from torch.fx import Node from ..._compatibility import compatibility @@ -19,7 +17,7 @@ def calculate_fwd_in(n: Node) -> bool: Returns: save_fwd_in (bool): the result of `save_fwd_in` """ - return n.meta['save_fwd_in'] + return n.meta["save_fwd_in"] @compatibility(is_backward_compatible=True) @@ -45,4 +43,4 @@ def calculate_fwd_out(n: Node) -> int: Returns: fwd_out (int): the result of `fwd_out` """ - return n.meta['fwd_mem_out'] + return n.meta["fwd_mem_out"] diff --git a/colossalai/fx/profiler/memory_utils.py b/colossalai/fx/profiler/memory_utils.py index 6ccbcb01cdc1..e8eb5f25cb6c 100644 --- a/colossalai/fx/profiler/memory_utils.py +++ b/colossalai/fx/profiler/memory_utils.py @@ -1,11 +1,11 @@ from typing import Dict, List, Tuple, Union import torch -from torch.fx import GraphModule, Node +from torch.fx import Node from .._compatibility import compatibility, is_compatible_with_meta -__all__ = ['activation_size', 'parameter_size', 'is_inplace'] +__all__ = ["activation_size", "parameter_size", "is_inplace"] @compatibility(is_backward_compatible=True) @@ -63,6 +63,7 @@ def is_inplace(n: Node): inplace = n.kwargs.get("inplace", False) if is_compatible_with_meta(): from .constants import ALIAS_ATEN + if n.target in ALIAS_ATEN: inplace = True elif n.op == "call_module": diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index ba090a2ec51b..8fae0f2ecb45 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -173,8 +173,11 @@ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: # Inputs[0] contains the shape of the input. input_shape = inputs[input_arg_index].shape - has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], - 'shape') else inputs[affine_arg_index] + has_affine = ( + inputs[affine_arg_index].shape is not None + if hasattr(inputs[affine_arg_index], "shape") + else inputs[affine_arg_index] + ) assert 2 <= len(input_shape) <= 5, input_shape # 5 is just a rough estimate flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) @@ -188,7 +191,7 @@ def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = N training = inputs[-3] assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" if training: - return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore + return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore has_affine = inputs[1].shape is not None input_shape = reduce(operator.mul, inputs[0].shape) return input_shape * (2 if has_affine else 1) @@ -218,15 +221,16 @@ def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number: def zero_flop_jit(*args): """ - Count flops for zero flop layers. + Count flops for zero flop layers. """ return 0 -if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( - torch.__version__) < version.parse('2.0.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0") and version.parse(torch.__version__) < version.parse( + "2.0.0" +): flop_mapping = { - # gemm, gemv and dot + # gemm, gemv and dot aten.mm.default: matmul_flop_jit, aten.mv.default: matmul_flop_jit, aten.dot.default: matmul_flop_jit, @@ -234,13 +238,11 @@ def zero_flop_jit(*args): aten.addmm.default: addmm_flop_jit, aten.bmm.default: bmm_flop_jit, aten.baddbmm.default: baddbmm_flop_jit, - - # convolution + # convolution aten.convolution.default: conv_flop_jit, aten._convolution.default: conv_flop_jit, aten.convolution_backward.default: conv_backward_flop_jit, - - # normalization + # normalization aten.native_batch_norm.default: batchnorm_flop_jit, aten.native_batch_norm_backward.default: batchnorm_flop_jit, aten.cudnn_batch_norm.default: batchnorm_flop_jit, @@ -249,8 +251,7 @@ def zero_flop_jit(*args): aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), aten.native_group_norm.default: norm_flop_counter(2, 0), aten.native_group_norm_backward.default: norm_flop_counter(2, 0), - - # pooling + # pooling aten.avg_pool1d.default: elementwise_flop_counter(1, 0), aten.avg_pool2d.default: elementwise_flop_counter(1, 0), aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1), @@ -275,7 +276,7 @@ def zero_flop_jit(*args): } elementwise_flop_aten = [ - # basic op + # basic op aten.add.Tensor, aten.add_.Tensor, aten.div.Tensor, @@ -296,8 +297,7 @@ def zero_flop_jit(*args): aten.exp.default, aten.sin.default, aten.cos.default, - - # activation op + # activation op aten.hardswish.default, aten.hardswish_.default, aten.hardswish_backward.default, @@ -320,8 +320,7 @@ def zero_flop_jit(*args): aten.tanh.default, aten.tanh_backward.default, aten.threshold_backward.default, - - # dropout + # dropout aten.native_dropout.default, aten.native_dropout_backward.default, ] @@ -362,7 +361,7 @@ def zero_flop_jit(*args): aten.zero_.default, aten.zeros_like.default, aten.fill_.Scalar, - aten.stack.default + aten.stack.default, ] # yapf: disable for op in zero_flop_aten: diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index c87cd4321d31..97e70db6290e 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -15,7 +15,7 @@ from .opcount import flop_mapping from .tensor import MetaTensor -__all__ = ['profile_function', 'profile_module', 'profile_method'] +__all__ = ["profile_function", "profile_module", "profile_method"] # super-dainiu: this cache should be global, otherwise it cannot # track duplicated tensors between nodes @@ -174,7 +174,6 @@ def _profile_meta(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], G # backward is executed. # Hopefully, this attempt will provide a better estimation of memory. class FlopTensor(MetaTensor): - _node: Node = None def __repr__(self): @@ -186,24 +185,24 @@ def __repr__(self): def __torch_dispatch__(cls, func, types, args=(), kwargs=None): args_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, args) kwargs_node = tree_map(lambda x: x._node if isinstance(x, FlopTensor) else None, kwargs) - node = subgraph.create_node('call_function', func, args_node, kwargs_node) + node = subgraph.create_node("call_function", func, args_node, kwargs_node) out = super().__torch_dispatch__(func, types, args, kwargs) flop_count[phase] += flop_mapping[func](args, normalize_tuple(out)) - node.meta['phase'] = phase + node.meta["phase"] = phase # super-dainiu: in `nn.MultiheadAttention` this weird thing occurs, # i.e. `Phase.PLACEHOLDER` tensors are aliased and saved during # `Phase.FORWARD` if phase == Phase.FORWARD: if all(map(partial(is_phase, phase=Phase.PLACEHOLDER), node.all_input_nodes)) and func in ALIAS_ATEN: - node.meta['phase'] = Phase.PLACEHOLDER + node.meta["phase"] = Phase.PLACEHOLDER # TODO(yby): specify `saved_tensors` for backward memory estimation - node.meta['saved_tensor'] = [] + node.meta["saved_tensor"] = [] if phase == Phase.BACKWARD: - node.meta['saved_tensor'] = normalize_tuple(out) + node.meta["saved_tensor"] = normalize_tuple(out) def wrap(x): if isinstance(x, MetaTensor): @@ -219,11 +218,14 @@ def wrap(x): x = FlopTensor(x) if is_autogradable(x): x.requires_grad_(True) - x._node = subgraph.create_node('placeholder', - 'placeholder', (subgraph._root,), - name=subgraph._graph_namespace.create_name('input', x._tensor)) - x._node.meta['phase'] = Phase.PLACEHOLDER - x._node.meta['saved_tensor'] = [] + x._node = subgraph.create_node( + "placeholder", + "placeholder", + (subgraph._root,), + name=subgraph._graph_namespace.create_name("input", x._tensor), + ) + x._node.meta["phase"] = Phase.PLACEHOLDER + x._node.meta["saved_tensor"] = [] return x # Basically, we need to detach the args and kwargs from the outer graph. @@ -235,7 +237,7 @@ def pack(x): if isinstance(x, FlopTensor) and not x._tensor.data_ptr() in cache: tensor = x._tensor.detach() tensor.data_ptr = x._tensor.data_ptr - x._node.meta['saved_tensor'] += [tensor] + x._node.meta["saved_tensor"] += [tensor] if not do_not_cache: cache.add(x._tensor.data_ptr()) return x @@ -284,7 +286,7 @@ def unwrap(x): @compatibility(is_backward_compatible=True) -def profile_function(target: 'Target', device: str = 'meta') -> Callable: +def profile_function(target: "Target", device: str = "meta") -> Callable: """ Wrap a `call_function` node or `torch.nn.functional` in order to record the memory cost and FLOPs of the execution. @@ -300,7 +302,6 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: - # find the grad for parameter in args and kwargs param_size = 0 @@ -316,18 +317,18 @@ def get_param_size(x): # still run the profiling but discard some results regarding `target` global do_not_cache - inplace = kwargs.get('inplace', False) + inplace = kwargs.get("inplace", False) if target in OUTPUT_SAVED_OPS: do_not_cache = True if inplace: do_not_cache = True - kwargs['inplace'] = False - if device == 'meta': + kwargs["inplace"] = False + if device == "meta": out, meta = _profile_meta(func, *args, **kwargs) else: out, meta = _profile_concrete(func, *args, **kwargs) if inplace: - kwargs['inplace'] = True + kwargs["inplace"] = True meta.bwd_mem_tmp = 0 meta.bwd_mem_out = 0 do_not_cache = False @@ -341,7 +342,7 @@ def get_param_size(x): @compatibility(is_backward_compatible=True) -def profile_method(target: 'Target', device: str = 'meta') -> Callable: +def profile_method(target: "Target", device: str = "meta") -> Callable: """ Wrap a `call_method` node record the memory cost and FLOPs of the execution. @@ -349,8 +350,8 @@ def profile_method(target: 'Target', device: str = 'meta') -> Callable: def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # execute the method and return the result - assert isinstance(target, str), f'{target} instance is not str.' - if device == 'meta': + assert isinstance(target, str), f"{target} instance is not str." + if device == "meta": out, meta = _profile_meta(target, *args, **kwargs) else: out, meta = _profile_concrete(target, *args, **kwargs) @@ -360,7 +361,7 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: @compatibility(is_backward_compatible=True) -def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: +def profile_module(module: torch.nn.Module, device: str = "meta") -> Callable: """ Wrap a `call_module` node or `torch.nn` in order to record the memory cost and FLOPs of the execution. @@ -376,7 +377,6 @@ def profile_module(module: torch.nn.Module, device: str = 'meta') -> Callable: """ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: - # calculate parameter size param_size = parameter_size(module) @@ -384,13 +384,13 @@ def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any: # still run the profiling but discard some results regarding `module`. global do_not_cache - inplace = getattr(module, 'inplace', False) + inplace = getattr(module, "inplace", False) if type(module) in OUTPUT_SAVED_MOD: do_not_cache = True if inplace: do_not_cache = True module.inplace = False - if device == 'meta': + if device == "meta": out, meta = _profile_meta(func, *args, **kwargs) else: out, meta = _profile_concrete(func, *args, **kwargs) diff --git a/colossalai/fx/profiler/shard_utils.py b/colossalai/fx/profiler/shard_utils.py index 34feefb4336a..75b7c814f05f 100644 --- a/colossalai/fx/profiler/shard_utils.py +++ b/colossalai/fx/profiler/shard_utils.py @@ -59,9 +59,9 @@ def forward(self, input_2): Returns: bool: Whether the node is a ReLU-like node """ - if n.op == 'call_function': + if n.op == "call_function": return n.target in OUTPUT_SAVED_OPS - elif n.op == 'call_module': + elif n.op == "call_module": return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD return False diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py index 2ee5e5c47750..7c14b48bdaa1 100644 --- a/colossalai/fx/profiler/tensor.py +++ b/colossalai/fx/profiler/tensor.py @@ -1,13 +1,13 @@ import uuid import torch -from torch.types import _bool, _device, _dtype -from torch.utils._pytree import tree_flatten, tree_map +from torch.types import _device +from torch.utils._pytree import tree_map from .._compatibility import compatibility from .constants import ALIAS_ATEN -__all__ = ['MetaTensor'] +__all__ = ["MetaTensor"] def set_data_ptr(x): @@ -43,12 +43,13 @@ def __new__(cls, elem, fake_device=None): storage_offset=elem.storage_offset(), dtype=elem.dtype, layout=elem.layout, - device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')), - requires_grad=elem.requires_grad) # deceive the frontend for aten selections + device=fake_device or (elem.device if elem.device.type != "meta" else torch.device("cpu")), + requires_grad=elem.requires_grad, + ) # deceive the frontend for aten selections r._tensor = elem # ...the real tensor is held as an element on the tensor. if not r._tensor.is_meta: - r._tensor = r._tensor.to(torch.device('meta')) + r._tensor = r._tensor.to(torch.device("meta")) # only tensor not on `meta` should be copied to `meta` set_data_ptr(r._tensor) return r @@ -69,15 +70,15 @@ def unwrap(x): x = x._tensor elif isinstance(x, torch.Tensor): fake_device = x.device - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return x args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) - if 'device' in kwargs: - fake_device = kwargs['device'] - kwargs['device'] = torch.device('meta') + if "device" in kwargs: + fake_device = kwargs["device"] + kwargs["device"] = torch.device("meta") # run aten for backend=CPU but actually on backend=Meta out = func(*args, **kwargs) @@ -93,7 +94,7 @@ def wrap(x): if isinstance(x, torch.Tensor): nonlocal fake_device if not x.is_meta: - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return MetaTensor(x, fake_device=fake_device) if isinstance(x, torch.Tensor) else x return tree_map(wrap, out) @@ -120,18 +121,18 @@ def replace(x): nonlocal fake_device if isinstance(x, str) or isinstance(x, _device): fake_device = x - return 'meta' + return "meta" return x elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs)) return MetaTensor(elem, fake_device=fake_device) def cpu(self, *args, **kwargs): - if self.device.type == 'cpu': + if self.device.type == "cpu": return self.to(*args, **kwargs) - return self.to(*args, device='cpu', **kwargs) + return self.to(*args, device="cpu", **kwargs) def cuda(self, device=None, non_blocking=False): if device is not None: return self.to(device=device, non_blocking=non_blocking) - return self.to(device='cuda:0', non_blocking=non_blocking) + return self.to(device="cuda:0", non_blocking=non_blocking) diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py index 7317072c6298..887832223fd6 100644 --- a/colossalai/fx/proxy.py +++ b/colossalai/fx/proxy.py @@ -1,12 +1,11 @@ -import operator -from typing import Any, List, Union +from typing import Any import torch -from torch.fx.proxy import Attribute, Proxy +from torch.fx.proxy import Proxy from colossalai.fx.tracer.meta_patch import meta_patched_function -__all__ = ['ColoProxy'] +__all__ = ["ColoProxy"] class ColoProxy(Proxy): @@ -39,11 +38,12 @@ def has_meta_data(self): return self._meta_data is not None def _assert_meta_data_is_tensor(self): - assert torch.is_tensor( - self._meta_data) and self._meta_data.is_meta, f'Meta data is not a meta tensor for {self.node.name}' + assert ( + torch.is_tensor(self._meta_data) and self._meta_data.is_meta + ), f"Meta data is not a meta tensor for {self.node.name}" def _assert_has_meta_data(self): - assert self._meta_data is not None, f'Meta data is not set for {self.node.name}' + assert self._meta_data is not None, f"Meta data is not set for {self.node.name}" def __len__(self): self._assert_has_meta_data() @@ -62,7 +62,6 @@ def __bool__(self): return self.meta_data def __getattr__(self, k): - return ColoAttribute(self, k) def __contains__(self, key): @@ -92,7 +91,6 @@ def _convert(val): class ColoAttribute(ColoProxy): - def __init__(self, root, attr: str): self.root = root self.attr = attr diff --git a/colossalai/fx/tracer/_meta_trace.py b/colossalai/fx/tracer/_meta_trace.py index 1c5abb81d271..63a7bab654d5 100644 --- a/colossalai/fx/tracer/_meta_trace.py +++ b/colossalai/fx/tracer/_meta_trace.py @@ -39,7 +39,7 @@ class MetaProxy(torch.Tensor): _tensor: torch.Tensor _node: Node - __slots__ = ['_tensor', '_node'] + __slots__ = ["_tensor", "_node"] @staticmethod def __new__(cls, tensor, fake_device=None, placeholder=False, name=None): @@ -51,22 +51,22 @@ def __new__(cls, tensor, fake_device=None, placeholder=False, name=None): dtype=tensor.dtype, layout=tensor.layout, device=fake_device if fake_device is not None else tensor.device, - requires_grad=tensor.requires_grad) # deceive the frontend for aten selections + requires_grad=tensor.requires_grad, + ) # deceive the frontend for aten selections r._tensor = tensor if placeholder: if name is None: - name = 'input' - r._node = graph.create_node('placeholder', - 'placeholder', (graph._root,), - name=namespace.create_name(name, tensor)) + name = "input" + r._node = graph.create_node( + "placeholder", "placeholder", (graph._root,), name=namespace.create_name(name, tensor) + ) # ...the real tensor is held as an element on the tensor. if not r._tensor.is_meta: - r._tensor = r._tensor.to(torch.device('meta')) + r._tensor = r._tensor.to(torch.device("meta")) return r @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - def unwrap(x): nonlocal fake_device if isinstance(x, MetaProxy): @@ -75,21 +75,21 @@ def unwrap(x): # assert not isinstance(x, MetaProxy) elif isinstance(x, torch.Tensor): fake_device = x.device - x = x.to(torch.device('meta')) + x = x.to(torch.device("meta")) return x def get_node(x): - if isinstance(x, torch.Tensor) and not hasattr(x, '_node'): - x = MetaProxy(x, placeholder=True, name='weight') - return x if not hasattr(x, '_node') else x._node + if isinstance(x, torch.Tensor) and not hasattr(x, "_node"): + x = MetaProxy(x, placeholder=True, name="weight") + return x if not hasattr(x, "_node") else x._node args_node = tree_map(get_node, args) kwargs_node = tree_map(get_node, kwargs) - node = graph.create_node('call_function', func, args_node, kwargs_node) + node = graph.create_node("call_function", func, args_node, kwargs_node) - if 'device' in kwargs: - fake_device = kwargs['device'] - kwargs['device'] = torch.device('meta') + if "device" in kwargs: + fake_device = kwargs["device"] + kwargs["device"] = torch.device("meta") args = tree_map(unwrap, args) kwargs = tree_map(unwrap, kwargs) @@ -103,9 +103,12 @@ def wrap(x): if isinstance(x, torch.Tensor): nonlocal fake_device if not x.is_meta: - x = x.to(torch.device('meta')) - return MetaProxy( - x, fake_device=fake_device) if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor') else x + x = x.to(torch.device("meta")) + return ( + MetaProxy(x, fake_device=fake_device) + if isinstance(x, torch.Tensor) and not hasattr(x, "_tensor") + else x + ) def set_node(x): x._node = node @@ -125,9 +128,12 @@ def wrap(x): for tensor in normalize_tuple(out): if is_autogradable(tensor) and tensor.requires_grad: - grad = torch.empty_like(tensor._tensor, device=torch.device('meta')) if isinstance( - tensor, MetaProxy) else torch.empty_like(tensor, device=torch.device('meta')) - torch.autograd.backward(tensor, - MetaProxy(grad, fake_device=tensor.device, placeholder=True), - retain_graph=True) + grad = ( + torch.empty_like(tensor._tensor, device=torch.device("meta")) + if isinstance(tensor, MetaProxy) + else torch.empty_like(tensor, device=torch.device("meta")) + ) + torch.autograd.backward( + tensor, MetaProxy(grad, fake_device=tensor.device, placeholder=True), retain_graph=True + ) return graph diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py index e160497a7444..9cf1961d45ff 100644 --- a/colossalai/fx/tracer/_tracer_utils.py +++ b/colossalai/fx/tracer/_tracer_utils.py @@ -2,10 +2,10 @@ import torch -from ..proxy import ColoAttribute, ColoProxy -from .meta_patch import meta_patched_function, meta_patched_module +from ..proxy import ColoProxy +from .meta_patch import meta_patched_function -__all__ = ['is_element_in_list', 'extract_meta'] +__all__ = ["is_element_in_list", "extract_meta"] def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): @@ -21,7 +21,6 @@ def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]): def extract_meta(*args, **kwargs): - def _convert(val): if isinstance(val, ColoProxy): return val.meta_data diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py index 859a19bf6241..84c09109877e 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addbmm.py @@ -1,7 +1,4 @@ -import operator - import torch -import torch.nn.functional as F from ...registry import bias_addition_function, bias_addition_method from .bias_addition_function import LinearBasedBiasFunc @@ -10,13 +7,12 @@ @bias_addition_method.register(torch.Tensor.addbmm) @bias_addition_function.register(torch.addbmm) class Addbmm(LinearBasedBiasFunc): - def extract_kwargs_from_origin_func(self): kwargs = {} - if 'beta' in self.kwargs: - kwargs['beta'] = self.kwargs['beta'] - if 'alpha' in self.kwargs: - kwargs['alpha'] = self.kwargs['alpha'] + if "beta" in self.kwargs: + kwargs["beta"] = self.kwargs["beta"] + if "alpha" in self.kwargs: + kwargs["alpha"] = self.kwargs["alpha"] return kwargs def create_non_bias_func_proxy(self, input_proxy, other_proxy): @@ -25,7 +21,7 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy): compute the main computation, such as convolution, with bias option banned. """ assert self.substitute_func == torch.bmm - node_kind = 'call_function' + node_kind = "call_function" node_target = self.substitute_func node_args = (input_proxy, other_proxy) @@ -35,10 +31,10 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy): return non_bias_func_proxy def insert_sum_node(self, input_proxy, sum_dims=0): - ''' + """ This method is used to sum the input_proxy through the sum_dims. - ''' - node_kind = 'call_function' + """ + node_kind = "call_function" node_target = torch.sum node_args = (input_proxy, sum_dims) node_kwargs = {} @@ -55,15 +51,15 @@ def generate(self): sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy) kwargs = self.extract_kwargs_from_origin_func() - if 'beta' in kwargs: - beta = kwargs['beta'] + if "beta" in kwargs: + beta = kwargs["beta"] # doing the multiplication with beta if it exists(temp_2 = beta * input) beta_proxy = self.create_mul_node(self.args[0], beta) else: beta_proxy = self.args[0] - if 'alpha' in kwargs: - alpha = kwargs['alpha'] + if "alpha" in kwargs: + alpha = kwargs["alpha"] # doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1) alpha_proxy = self.create_mul_node(alpha, sum_proxy) else: diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py index fe7d8d07aac9..d087b2913005 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/addmm.py @@ -1,7 +1,4 @@ -import operator - import torch -import torch.nn.functional as F from ...registry import bias_addition_function, bias_addition_method from .bias_addition_function import LinearBasedBiasFunc @@ -10,17 +7,16 @@ @bias_addition_method.register(torch.Tensor.addmm) @bias_addition_function.register(torch.addmm) class Addmm(LinearBasedBiasFunc): - def extract_kwargs_from_origin_func(self): kwargs = {} - if 'beta' in self.kwargs: - kwargs['beta'] = self.kwargs['beta'] - if 'alpha' in self.kwargs: - kwargs['alpha'] = self.kwargs['alpha'] + if "beta" in self.kwargs: + kwargs["beta"] = self.kwargs["beta"] + if "alpha" in self.kwargs: + kwargs["alpha"] = self.kwargs["alpha"] return kwargs def transpose_other_operand_for_linear(self, other_proxy): - ''' + """ This method is used to transpose the other operand for linear function. For example: input = torch.rand(3, 4) @@ -30,8 +26,8 @@ def transpose_other_operand_for_linear(self, other_proxy): # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2 # before we call the linear function. new_output = torch.linear(m1, m2.transpose(0, 1)) + input - ''' - node_kind = 'call_function' + """ + node_kind = "call_function" node_target = torch.transpose node_args = (other_proxy, 0, 1) node_kwargs = {} @@ -43,14 +39,14 @@ def generate(self): non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy) kwargs = self.extract_kwargs_from_origin_func() - if 'beta' in kwargs: - beta = kwargs['beta'] + if "beta" in kwargs: + beta = kwargs["beta"] beta_proxy = self.create_mul_node(self.args[0], beta) else: beta_proxy = self.args[0] - if 'alpha' in kwargs: - alpha = kwargs['alpha'] + if "alpha" in kwargs: + alpha = kwargs["alpha"] alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy) else: alpha_proxy = non_bias_linear_func_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py index 8a3786332c08..42178b7b786e 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/bias_addition_function.py @@ -29,7 +29,6 @@ def extract_kwargs_from_origin_func(self): to insert two more operator.mul nodes for the computation graph to compute the final result. """ - pass @abstractmethod def generate(self): @@ -50,7 +49,6 @@ def generate(self): %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {}) %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {}) """ - pass def create_mul_node(self, input_proxy, coefficent): """ @@ -59,7 +57,7 @@ def create_mul_node(self, input_proxy, coefficent): Therefore, we need to use this method insert two more operator.mul nodes for the computation graph to compute the final result. """ - node_kind = 'call_function' + node_kind = "call_function" node_target = operator.mul node_args = ( input_proxy, @@ -82,7 +80,7 @@ def create_non_bias_func_proxy(self, input_proxy, other_proxy): compute the main computation, such as convolution, with bias option banned. """ assert self.substitute_func == torch.nn.functional.linear - node_kind = 'call_function' + node_kind = "call_function" node_target = self.substitute_func node_args = (input_proxy, other_proxy) @@ -96,7 +94,7 @@ def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): This method is used to create the bias_addition_proxy, the node created by this proxy will compute the sum of non_bias_func result and bias with some reshape operation if needed. """ - bias_add_node_kind = 'call_function' + bias_add_node_kind = "call_function" bias_add_node_target = operator.add bias_add_args = (non_bias_func_proxy, bias_proxy) bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py index e11ec0a364f1..ed060a350739 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_function/linear.py @@ -1,6 +1,3 @@ -import operator - -import torch import torch.nn.functional as F from ...registry import bias_addition_function @@ -9,17 +6,16 @@ @bias_addition_function.register(F.linear) class Linear(LinearBasedBiasFunc): - def extract_kwargs_from_origin_func(self): - assert 'bias' in self.kwargs + assert "bias" in self.kwargs kwargs = {} - if 'bias' in self.kwargs: - kwargs['bias'] = self.kwargs['bias'] + if "bias" in self.kwargs: + kwargs["bias"] = self.kwargs["bias"] return kwargs def generate(self): non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1]) kwargs = self.extract_kwargs_from_origin_func() - bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias']) + bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs["bias"]) return bias_addition_proxy diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py index 591485fdb1ca..19c0e21d7c17 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/bias_addition_module.py @@ -27,8 +27,8 @@ def _create_weight_proxy(self): Note: this function will be invoked during module initializing, you should never call this function. """ - weight_node_kind = 'get_attr' - weight_node_target = self.target + '.weight' + weight_node_kind = "get_attr" + weight_node_target = self.target + ".weight" weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {}) return weight_proxy @@ -39,8 +39,8 @@ def _create_bias_proxy(self): Note: this function will be invoked during module initializing, you should never call this function. """ - bias_node_kind = 'get_attr' - bias_node_target = self.target + '.bias' + bias_node_kind = "get_attr" + bias_node_target = self.target + ".bias" bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {}) return bias_proxy @@ -54,14 +54,13 @@ def extract_kwargs_from_mod(self): considered during module initializing. However, we need to consider those attributes as kwargs in F.conv2d. """ - pass def create_non_bias_func_proxy(self, input_proxy=None): """ This method is used to create the non_bias_func proxy, the node created by this proxy will compute the main computation, such as convolution, with bias option banned. """ - node_kind = 'call_function' + node_kind = "call_function" node_target = self.substitute_func if input_proxy is None: input_proxy = self.args[0] @@ -75,7 +74,7 @@ def create_bias_addition_proxy(self, non_bias_func_proxy, bias_proxy): This method is used to create the bias_addition_proxy, the node created by this proxy will compute the sum of non_bias_func result and bias with some reshape operation if needed. """ - bias_add_node_kind = 'call_function' + bias_add_node_kind = "call_function" bias_add_node_target = operator.add bias_add_args = (non_bias_func_proxy, bias_proxy) bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {}) @@ -100,7 +99,6 @@ def generate(self): %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) """ - pass module_to_func_dict = { diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py index 4b6c82a74f57..812a141c1eab 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py @@ -1,6 +1,5 @@ import torch -import torch.nn.functional as F -from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple +from torch.nn.modules.utils import _pair, _single, _triple from ...registry import bias_addition_module from .bias_addition_module import BiasAdditionModule @@ -10,17 +9,16 @@ @bias_addition_module.register(torch.nn.Conv2d) @bias_addition_module.register(torch.nn.Conv3d) class BiasAdditionConv(BiasAdditionModule): - def extract_kwargs_from_mod(self): root = self.tracer.root conv_module = root.get_submodule(self.target) - kwarg_attributes = ['groups', 'dilation', 'stride'] + kwarg_attributes = ["groups", "dilation", "stride"] non_bias_kwargs = {} for attr_name in kwarg_attributes: if hasattr(conv_module, attr_name): non_bias_kwargs[attr_name] = getattr(conv_module, attr_name) if conv_module.padding_mode != "zeros": - #TODO: non zeros mode requires some extra processing for input + # TODO: non zeros mode requires some extra processing for input conv_type = type(conv_module) if conv_type == "torch.nn.Conv1d": padding_element = _single(0) @@ -28,9 +26,9 @@ def extract_kwargs_from_mod(self): padding_element = _pair(0) elif conv_type == "torch.nn.Conv3d": padding_element = _triple(0) - non_bias_kwargs['padding'] = padding_element + non_bias_kwargs["padding"] = padding_element else: - non_bias_kwargs['padding'] = getattr(conv_module, 'padding') + non_bias_kwargs["padding"] = getattr(conv_module, "padding") return non_bias_kwargs @@ -41,11 +39,12 @@ def create_bias_reshape_proxy(self, dimensions): """ bias_shape = [1] * (dimensions - 1) bias_shape[0] = -1 - bias_reshape_node_kind = 'call_method' - bias_reshape_node_target = 'view' + bias_reshape_node_kind = "call_method" + bias_reshape_node_target = "view" bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape)) - bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target, - bias_reshape_node_args, {}) + bias_reshape_proxy = self.tracer.create_proxy( + bias_reshape_node_kind, bias_reshape_node_target, bias_reshape_node_args, {} + ) return bias_reshape_proxy def generate(self): diff --git a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py index f6f7b6ddab40..b397f009846c 100644 --- a/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py +++ b/colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/linear.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F from ...registry import bias_addition_module from .bias_addition_module import BiasAdditionModule @@ -7,7 +6,6 @@ @bias_addition_module.register(torch.nn.Linear) class BiasAdditionLinear(BiasAdditionModule): - def extract_kwargs_from_mod(self): return {} diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py index 22a67d1ceccc..e6e511b72fbb 100644 --- a/colossalai/fx/tracer/experimental.py +++ b/colossalai/fx/tracer/experimental.py @@ -1,4 +1,3 @@ -import enum import functools import inspect import operator @@ -10,7 +9,7 @@ from torch.utils._pytree import tree_map from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta -from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list +from colossalai.fx.tracer._tracer_utils import is_element_in_list from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict from colossalai.fx.tracer.registry import ( bias_addition_function, @@ -24,31 +23,45 @@ from colossalai.fx.profiler import MetaTensor Target = Union[Callable[..., Any], str] -Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types - List[Any], # actually Argument - Dict[str, Any], # actually Argument - slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing - 'Node',]] -_CScriptMethod = ['add', 'mul', 'sub', 'div'] +Argument = Optional[ + Union[ + Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + "Node", + ] +] +_CScriptMethod = ["add", "mul", "sub", "div"] _TorchNewMethod = [ - "arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor", - "finfo" + "arange", + "zeros", + "zeros_like", + "ones", + "ones_like", + "full", + "full_like", + "empty", + "empty_like", + "eye", + "tensor", + "finfo", ] _TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"] def _truncate_suffix(s: str): import re - return re.sub(r'_\d+$', '', s) + + return re.sub(r"_\d+$", "", s) def default_device(): - return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") @compatibility(is_backward_compatible=False) class ColoProxy(Proxy): - def __init__(self, *args, data=None, **kwargs): super().__init__(*args, **kwargs) self._meta_data = data @@ -100,7 +113,7 @@ def __getattr__(self, k): return ColoAttribute(self, k, getattr(self._meta_data, k, None)) def __setitem__(self, key, value): - proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) + proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {}) proxy.meta_data = self._meta_data return proxy @@ -125,29 +138,28 @@ def ndim(self): @property def device(self): - proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {}) + proxy = self.tracer.create_proxy("call_function", getattr, (self, "device"), {}) proxy.meta_data = self.meta_data.device return proxy @property def dtype(self): - proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {}) + proxy = self.tracer.create_proxy("call_function", getattr, (self, "dtype"), {}) proxy.meta_data = self.meta_data.dtype return proxy def to(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs}) + return self.tracer.create_proxy("call_method", "to", (self, *args), {**kwargs}) def cpu(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs}) + return self.tracer.create_proxy("call_method", "cpu", (self, *args), {**kwargs}) def cuda(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs}) + return self.tracer.create_proxy("call_method", "cuda", (self, *args), {**kwargs}) @compatibility(is_backward_compatible=False) class ColoAttribute(ColoProxy): - def __init__(self, root, attr: str, data=None): self.root = root self.attr = attr @@ -160,11 +172,11 @@ def node(self): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node return self._node def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) def __repr__(self): return f"ColoAttribute({self.node.name}, attr={self.attr})" @@ -172,7 +184,6 @@ def __repr__(self): @compatibility(is_backward_compatible=False) class ColoTracer(Tracer): - def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): super().__init__(*args, **kwargs) self._disable_module_getattr = False @@ -184,24 +195,28 @@ def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): self.inside_torch_checkpoint_func = False self.act_ckpt_region_count = 0 - def proxy(self, node: Node) -> 'ColoProxy': + def proxy(self, node: Node) -> "ColoProxy": return ColoProxy(node, self) - def create_proxy(self, - kind: str, - target: Target, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - name: Optional[str] = None, - type_expr: Optional[Any] = None, - proxy_factory_fn: Callable[[Node], 'Proxy'] = None): - + def create_proxy( + self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], "Proxy"] = None, + ): proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p - if kind == 'placeholder': - proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( - _truncate_suffix(target), None) - elif kind == 'get_attr': + if kind == "placeholder": + proxy.meta_data = ( + self.meta_args[target] + if target in self.meta_args + else self.concrete_args.get(_truncate_suffix(target), None) + ) + elif kind == "get_attr": self._disable_module_getattr = True try: attr_itr = self.root @@ -211,20 +226,21 @@ def create_proxy(self, proxy.meta_data = attr_itr finally: self._disable_module_getattr = False - elif kind == 'call_function': + elif kind == "call_function": proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_method': + elif kind == "call_method": self._disable_module_getattr = True try: - if target == '__call__': + if target == "__call__": proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) else: if target not in _TensorPropertyMethod: - proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), - **tree_map(unwrap_fn, kwargs)) + proxy._meta_data = getattr(unwrap_fn(args[0]), target)( + *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs) + ) finally: self._disable_module_getattr = False - elif kind == 'call_module': + elif kind == "call_module": mod = self.root.get_submodule(target) self._disable_module_getattr = True try: @@ -238,14 +254,15 @@ def create_node(self, *args, **kwargs) -> Node: if self.inside_torch_checkpoint_func: # annotate the activation checkpoint module - node.meta['activation_checkpoint'] = self.act_ckpt_region_count + node.meta["activation_checkpoint"] = self.act_ckpt_region_count return node - def trace(self, - root: torch.nn.Module, - concrete_args: Optional[Dict[str, torch.Tensor]] = None, - meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: - + def trace( + self, + root: torch.nn.Module, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, + meta_args: Optional[Dict[str, torch.Tensor]] = None, + ) -> Graph: if meta_args is None: meta_args = {} @@ -260,20 +277,19 @@ def trace(self, # update concrete args with default values non_meta_arg_names = sig_names - meta_arg_names for k, v in sig.parameters.items(): - if k in non_meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default # get non concrete arg names concrete_arg_names = set(concrete_args.keys()) - non_concrete_arg_names = sig_names - concrete_arg_names + sig_names - concrete_arg_names def _check_arg_name_valid(names): success, element = is_element_in_list(names, sig_names) if not success: raise KeyError( - f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function" + ) _check_arg_name_valid(meta_arg_names) _check_arg_name_valid(concrete_arg_names) @@ -292,7 +308,6 @@ def trace_activation_checkpoint(self, enabled: bool): orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction class PatchedCheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): # signal that the current tracing occurs within activation checkpoint part @@ -305,7 +320,8 @@ def forward(ctx, run_function, preserve_rng_state, *args): @staticmethod def backward(ctx: Any, *grad_outputs: Any) -> Any: raise NotImplementedError( - "We do not implement the backward pass as we only trace the forward pass.") + "We do not implement the backward pass as we only trace the forward pass." + ) # override the checkpoint function torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction @@ -356,10 +372,13 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac if attr_val is p: if n not in parameter_proxy_cache: kwargs = {} - if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: - kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else - lambda node: ColoProxy(self, node, n, attr_val)) - val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] + if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ColoProxy(self, node, n, attr_val) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] parameter_proxy_cache[n] = val_proxy return parameter_proxy_cache[n] return None @@ -370,8 +389,9 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac return maybe_buffer_proxy if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), - parameter_proxy_cache) + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) if maybe_parameter_proxy is not None: return maybe_parameter_proxy @@ -389,42 +409,41 @@ def symbolic_trace( if meta_args is not None: root.to(default_device()) wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x - graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, - concrete_args=concrete_args, - meta_args=tree_map(wrap_fn, meta_args)) + graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace( + root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args) + ) root.cpu() else: graph = Tracer().trace(root, concrete_args=concrete_args) else: from .tracer import ColoTracer as OrigColoTracer - graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, - concrete_args=concrete_args, - meta_args=meta_args) + + graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace( + root, concrete_args=concrete_args, meta_args=meta_args + ) name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ return ColoGraphModule(root, graph, name) @compatibility(is_backward_compatible=False) class _TorchTensorOverride(object): - def __init__(self, tracer: Tracer): self.overrides = {} self.tracer = tracer def __enter__(self): - def wrap_tensor_method(target): - @functools.wraps(target) def wrapper(*args, **kwargs): is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( - isinstance(p, ColoProxy) for p in kwargs.values()) + isinstance(p, ColoProxy) for p in kwargs.values() + ) if is_proxy: # if the arg is a proxy, then need to record this function called on this proxy # e.g. torch.ones(size) where size is an input proxy self.tracer._disable_module_getattr = True try: - proxy = self.tracer.create_proxy('call_function', target, args, kwargs) + proxy = self.tracer.create_proxy("call_function", target, args, kwargs) finally: self.tracer._disable_module_getattr = False return proxy @@ -446,11 +465,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): setattr(torch, name, orig) -def meta_prop_pass(gm: ColoGraphModule, - root: torch.nn.Module, - meta_args: Optional[Dict[str, Any]] = None, - concrete_args: Optional[Dict[str, torch.Tensor]] = None): - +def meta_prop_pass( + gm: ColoGraphModule, + root: torch.nn.Module, + meta_args: Optional[Dict[str, Any]] = None, + concrete_args: Optional[Dict[str, torch.Tensor]] = None, +): if meta_args is None: meta_args = {} @@ -465,36 +485,36 @@ def meta_prop_pass(gm: ColoGraphModule, # update concrete args with default values non_meta_arg_names = sig_names - meta_arg_names for k, v in sig.parameters.items(): - if k in non_meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default for node in gm.graph.nodes: - node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args, - node.kwargs) + node._meta_data = _meta_data_computing( + meta_args, concrete_args, root, node.op, node.target, node.args, node.kwargs + ) def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs): unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n - if kind == 'placeholder': + if kind == "placeholder": meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None) - elif kind == 'get_attr': + elif kind == "get_attr": attr_itr = root atoms = target.split(".") for atom in atoms: attr_itr = getattr(attr_itr, atom) meta_out = attr_itr - elif kind == 'call_function': + elif kind == "call_function": meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_method': - if target == '__call__': + elif kind == "call_method": + if target == "__call__": meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) else: if target not in _TensorPropertyMethod: - meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), - **tree_map(unwrap_fn, kwargs)) - elif kind == 'call_module': + meta_out = getattr(unwrap_fn(args[0]), target)( + *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs) + ) + elif kind == "call_module": mod = root.get_submodule(target) meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) else: @@ -603,26 +623,30 @@ def wrap_fn(n): if kind == "call_function": if bias_addition_function.has(target): if target == torch.nn.functional.linear: - if 'bias' in kwargs and kwargs['bias'] is not None: + if "bias" in kwargs and kwargs["bias"] is not None: function_to_substitute = func_to_func_dict[target] - handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_function.get(target)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) else: function_to_substitute = func_to_func_dict[target] - handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_function.get(target)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) elif bias_addition_function.has(target.__name__): # use name for some builtin op like @ (matmul) function_to_substitute = func_to_func_dict[target] - handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_function.get(target.__name__)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) elif kind == "call_method": method = getattr(args_metas[0].__class__, target) if bias_addition_method.has(method): function_to_substitute = method_to_func_dict[method] - handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_method.get(method)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) elif kind == "call_module": # if not hasattr(self, "orig_forward"): @@ -631,8 +655,9 @@ def wrap_fn(n): mod_type = type(mod) if bias_addition_module.has(mod_type) and mod.bias is not None: function_to_substitute = module_to_func_dict[mod_type] - handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, - function_to_substitute) + handle = bias_addition_module.get(mod_type)( + tracer, target, args_proxy, kwargs_proxy, function_to_substitute + ) if handle is not None: handle.generate() diff --git a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py index 12c42514895e..75d7b18a067c 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/activation_function.py @@ -5,4 +5,4 @@ @meta_patched_function.register(torch.nn.functional.relu) def torch_nn_func_relu(input, inplace=False): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py index 042b92c5847a..3475f22e3b19 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py @@ -4,7 +4,7 @@ @meta_patched_function.register(torch.matmul) -@meta_patched_function.register('matmul') # for built-in op @ +@meta_patched_function.register("matmul") # for built-in op @ def torch_matmul(input, other, *, out=None): # copied from huggingface.utils.fx d1 = input.dim() @@ -44,8 +44,8 @@ def torch_matmul(input, other, *, out=None): @meta_patched_function.register(torch.abs) def torch_abs(input, *, out=None): - assert out is None, 'out is not supported yet' - return torch.empty(input.shape, device='meta') + assert out is None, "out is not supported yet" + return torch.empty(input.shape, device="meta") @meta_patched_function.register(torch.bmm) @@ -89,7 +89,7 @@ def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): @meta_patched_function.register(torch.var_mean) def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None): - assert out is None, 'saving to out is not supported yet' - var = torch.empty(1).squeeze(0).to('meta') - mean = torch.empty(1).squeeze(0).to('meta') + assert out is None, "saving to out is not supported yet" + var = torch.empty(1).squeeze(0).to("meta") + mean = torch.empty(1).squeeze(0).to("meta") return var, mean diff --git a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py index 8500e5c82508..26daf32a2afc 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/convolution.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/convolution.py @@ -8,7 +8,6 @@ def _ntuple(n, name="parse"): - def parse(x): if isinstance(x, collections.abc.Iterable): return tuple(x) @@ -24,21 +23,21 @@ def parse(x): def _extract_kwargs(kwargs): - if 'stride' in kwargs: - stride = kwargs['stride'] + if "stride" in kwargs: + stride = kwargs["stride"] else: stride = 1 # TODO: process str type padding - if 'padding' in kwargs: - padding = kwargs['padding'] + if "padding" in kwargs: + padding = kwargs["padding"] else: padding = 0 - if 'dilation' in kwargs: - dilation = kwargs['dilation'] + if "dilation" in kwargs: + dilation = kwargs["dilation"] else: dilation = 1 - if 'output_padding' in kwargs: - output_padding = kwargs['output_padding'] + if "output_padding" in kwargs: + output_padding = kwargs["output_padding"] else: output_padding = 0 @@ -61,7 +60,7 @@ def torch_nn_functional_conv1d(input, weight, **kwargs): c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv2d) @@ -82,7 +81,7 @@ def torch_nn_functional_conv2d(input, weight, **kwargs): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv3d) @@ -105,7 +104,7 @@ def torch_nn_functional_conv3d(input, weight, **kwargs): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv_transpose1d) @@ -120,13 +119,14 @@ def torch_nn_functional_convtranspose1d(input, weight, **kwargs): kernel_size = weight.shape[2:] l_in = input.shape[-1] c_out = weight.shape[1] - l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + - output_padding[0] + 1) + l_out = math.floor( + (l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv_transpose2d) @@ -141,16 +141,18 @@ def torch_nn_functional_convtranspose2d(input, weight, **kwargs): kernel_size = weight.shape[2:] h_in, w_in = input.shape[-2:] c_out = weight.shape[1] - h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + - output_padding[0] + 1) - w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + - output_padding[1] + 1) + h_out = math.floor( + (h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1 + ) + w_out = math.floor( + (w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_function.register(torch.nn.functional.conv_transpose3d) @@ -165,16 +167,19 @@ def torch_nn_functional_convtranspose3d(input, weight, **kwargs): kernel_size = weight.shape[2:] d_in, h_in, w_in = input.shape[-3:] c_out = weight.shape[1] - d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + - output_padding[0] + 1) - h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + - output_padding[1] + 1) - w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + - output_padding[2] + 1) + d_out = math.floor( + (d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1 + ) + h_out = math.floor( + (h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1 + ) + w_out = math.floor( + (w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + output_padding[2] + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py index 6d8d864ea29a..27a79f18590a 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/embedding.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/embedding.py @@ -4,11 +4,7 @@ @meta_patched_function.register(torch.nn.functional.embedding) -def torch_nn_functional_embedding(input, - weight, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False): +def torch_nn_functional_embedding( + input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False +): return torch.empty(*input.shape, weight.shape[-1], device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py index e9e7eda6159c..8a6214990830 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/normalization.py @@ -5,16 +5,11 @@ @meta_patched_function.register(torch.nn.functional.layer_norm) def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") @meta_patched_function.register(torch.nn.functional.batch_norm) -def torch_nn_func_batchnorm(input, - running_mean, - running_var, - weight=None, - bias=None, - training=False, - momentum=0.1, - eps=1e-05): - return torch.empty(input.shape, device='meta') +def torch_nn_func_batchnorm( + input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05 +): + return torch.empty(input.shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py index 4c171cb10991..7642934a409b 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/python_ops.py @@ -19,9 +19,9 @@ def to_concrete(t): return t def _slice_convert(slice_obj): - attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step} + attrs = {"start": slice_obj.start, "stop": slice_obj.stop, "step": slice_obj.step} new_attrs = _slice_attr_convert(attrs) - attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step']) + attr_dict_to_tuple = (new_attrs["start"], new_attrs["stop"], new_attrs["step"]) return slice(*attr_dict_to_tuple) def _slice_attr_convert(attrs): diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py index b14ff10ce137..c61e1c4dc9e1 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py @@ -105,14 +105,15 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None): shapes = [t.shape for t in tensors] shape = list(shapes[0]) concatenated_dim = sum(shape[dim] for shape in shapes) - final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:] + final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :] return torch.empty(final_shape, device="meta") @meta_patched_function.register(torch.repeat_interleave) def torch_repeat_interleave(input, repeats, dim=None, output_size=None): - assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \ - "Argument 'repeats' should be of type 'torch.Tensor' or 'int'" + assert isinstance(repeats, int) or isinstance( + repeats, torch.Tensor + ), "Argument 'repeats' should be of type 'torch.Tensor' or 'int'" shape = list(input.shape) if dim is not None else [input.numel()] dim = dim if dim is not None else 0 @@ -132,36 +133,36 @@ def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None) @meta_patched_function.register(torch.roll) def torch_roll(input, shifts, dims=None): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") @meta_patched_function.register(torch.full) def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): - assert out is None, 'assigning result to out is not supported yet' - return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad) + assert out is None, "assigning result to out is not supported yet" + return torch.empty(size, device="meta", dtype=dtype, layout=layout, requires_grad=requires_grad) @meta_patched_function.register(torch.max) def torch_max(input, dim=None, keepdim=False, *, out=None): - assert out is None, 'assigning value to out is not supported yet' + assert out is None, "assigning value to out is not supported yet" if dim is not None: if isinstance(dim, int): shape = list(input.shape) shape.pop(dim) if keepdim: shape.insert(dim, 1) - return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape, - device='meta', - dtype=input.dtype) + return torch.empty(shape, device="meta", dtype=input.dtype), torch.empty( + shape, device="meta", dtype=input.dtype + ) elif isinstance(dim, torch.Tensor): # when dim is a 0D or 1D tensor, it will maintain the same shape num_dims = dim.dim() if num_dims in [0, 1]: - return torch.empty_like(input, device='meta') + return torch.empty_like(input, device="meta") else: raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions") else: - return torch.empty([], device='meta', dtype=input.dtype) + return torch.empty([], device="meta", dtype=input.dtype) @meta_patched_function.register(torch.Tensor.cpu) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py index e28e52585fff..3f40ec2a67ee 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py @@ -4,4 +4,4 @@ from .linear import * from .normalization import * from .pooling import * -from .rnn import * \ No newline at end of file +from .rnn import * diff --git a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py index d03da6588c1c..aa2ede187d37 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/activation_function.py @@ -10,4 +10,4 @@ @meta_patched_module.register(torch.nn.ReLU6) @meta_patched_module.register(torch.nn.PReLU) def torch_nn_non_linear_act(self, input): - return torch.empty(input.shape, device='meta') + return torch.empty(input.shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py index cf9f3487aac9..35173a68a0be 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/convolution.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/convolution.py @@ -11,13 +11,14 @@ def torch_nn_conv1d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d l_in = input.shape[-1] c_out = self.out_channels - l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + l_out = math.floor( + (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.Conv2d) @@ -26,16 +27,18 @@ def torch_nn_conv2d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d h_in, w_in = input.shape[-2:] c_out = self.out_channels - h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + h_out = math.floor( + (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.Conv3d) @@ -44,19 +47,22 @@ def torch_nn_conv3d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d d_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * - (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) - h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * - (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) - w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * - (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) + d_out = math.floor( + (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1 + ) + h_out = math.floor( + (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1 + ) + w_out = math.floor( + (w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.ConvTranspose1d) @@ -65,13 +71,18 @@ def torch_nn_convtranspose1d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html l_in = input.shape[-1] c_out = self.out_channels - l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + l_out = math.floor( + (l_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) result_shape = input.shape[:-2] + ( c_out, l_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.ConvTranspose2d) @@ -80,16 +91,26 @@ def torch_nn_convtranspose2d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html h_in, w_in = input.shape[-2:] c_out = self.out_channels - h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + h_out = math.floor( + (h_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) result_shape = input.shape[:-3] + ( c_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.ConvTranspose3d) @@ -98,16 +119,31 @@ def torch_nn_convtranspose3d(self, input): # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html d_in, h_in, w_in = input.shape[-3:] c_out = self.out_channels - d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * - (self.kernel_size[0] - 1) + self.output_padding[0] + 1) - h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - (self.kernel_size[1] - 1) + self.output_padding[1] + 1) - w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * - (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + d_out = math.floor( + (d_in - 1) * self.stride[0] + - 2 * self.padding[0] + + self.dilation[0] * (self.kernel_size[0] - 1) + + self.output_padding[0] + + 1 + ) + h_out = math.floor( + (h_in - 1) * self.stride[1] + - 2 * self.padding[1] + + self.dilation[1] * (self.kernel_size[1] - 1) + + self.output_padding[1] + + 1 + ) + w_out = math.floor( + (w_in - 1) * self.stride[2] + - 2 * self.padding[2] + + self.dilation[2] * (self.kernel_size[2] - 1) + + self.output_padding[2] + + 1 + ) result_shape = input.shape[:-4] + ( c_out, d_out, h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py index 999e33b17c1c..f28647e9caa5 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/embedding.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/embedding.py @@ -6,4 +6,4 @@ @meta_patched_module.register(torch.nn.Embedding) def torch_nn_embedding(self, input): result_shape = input.shape + (self.embedding_dim,) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/linear.py b/colossalai/fx/tracer/meta_patch/patched_module/linear.py index 56f13bf97532..97e6b0e96e83 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/linear.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/linear.py @@ -6,5 +6,7 @@ @meta_patched_module.register(torch.nn.Linear) def torch_nn_linear(self, input): last_dim = input.shape[-1] - assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch' + assert ( + last_dim == self.in_features + ), f"Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch" return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py index c21ff64cf3de..198e72e342b1 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/normalization.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/normalization.py @@ -23,6 +23,7 @@ def torch_nn_normalize(self, input): try: import apex + meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) diff --git a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py index 7ce23fbf7ac9..450586d02f8f 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/pooling.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/pooling.py @@ -8,7 +8,7 @@ @meta_patched_module.register(torch.nn.AvgPool1d) def torch_nn_avgpool1d(self, input): num_dim = input.dim() - assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' + assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions" l_in = input.shape[-1] @@ -25,13 +25,13 @@ def _convert_int_to_list(item): l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) result_shape = tuple(input.shape[:-1]) + (l_out,) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AvgPool2d) def torch_nn_avgpool2d(self, input): num_dim = input.dim() - assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' + assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions" h_in, w_in = input.shape[-2:] @@ -52,13 +52,13 @@ def _convert_int_to_list(item): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AvgPool3d) def torch_nn_avgpool3d(self, input): num_dim = input.dim() - assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions" d_in, h_in, w_in = input.shape[-3:] @@ -81,13 +81,13 @@ def _convert_int_to_list(item): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.MaxPool1d) def torch_nn_maxpool1d(self, input): num_dim = input.dim() - assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' + assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions" l_in = input.shape[-1] @@ -105,13 +105,13 @@ def _convert_int_to_list(item): l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) result_shape = tuple(input.shape[:-1]) + (l_out,) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.MaxPool2d) def torch_nn_maxpool2d(self, input): num_dim = input.dim() - assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' + assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions" h_in, w_in = input.shape[-2:] @@ -133,13 +133,13 @@ def _convert_int_to_list(item): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.MaxPool3d) def torch_nn_maxpool3d(self, input): num_dim = input.dim() - assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' + assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions" d_in, h_in, w_in = input.shape[-3:] @@ -163,7 +163,7 @@ def _convert_int_to_list(item): h_out, w_out, ) - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AdaptiveAvgPool1d) @@ -175,7 +175,7 @@ def torch_nn_adapative_pooling_1d(self, input): else: output_size = self.output_size result_shape = tuple(input.shape[:-1]) + output_size - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AdaptiveAvgPool2d) @@ -187,7 +187,7 @@ def torch_nn_adapative_pooling_2d(self, input): else: output_size = self.output_size result_shape = tuple(input.shape[:-2]) + output_size - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") @meta_patched_module.register(torch.nn.AdaptiveAvgPool3d) @@ -199,4 +199,4 @@ def torch_nn_adapative_pooling_3d(self, input): else: output_size = self.output_size result_shape = tuple(input.shape[:-3]) + output_size - return torch.empty(result_shape, device='meta') + return torch.empty(result_shape, device="meta") diff --git a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py index ee15ca34162e..bfb7ed171186 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/rnn.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/rnn.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from ...registry import meta_patched_module @@ -8,9 +6,11 @@ @meta_patched_module.register(torch.nn.GRU) @meta_patched_module.register(torch.nn.RNN) def torch_nn_rnn(self, input, hx): - assert input.shape[ - -1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch' - assert hx.shape[ - -1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch' + assert ( + input.shape[-1] == self.input_size + ), f"Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch" + assert ( + hx.shape[-1] == self.hidden_size + ), f"Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch" d = 2 if self.bidirectional else 1 return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx diff --git a/colossalai/fx/tracer/registry.py b/colossalai/fx/tracer/registry.py index 12fc6de73d44..80b3868bb4fe 100644 --- a/colossalai/fx/tracer/registry.py +++ b/colossalai/fx/tracer/registry.py @@ -1,11 +1,9 @@ class PatchRegistry: - def __init__(self, name): self.name = name self.store = {} def register(self, source): - def wrapper(func): self.store[source] = func return func @@ -21,8 +19,8 @@ def has(self, source): return source in self.store -meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution') -meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution') -bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition') -bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition') -bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition') +meta_patched_function = PatchRegistry(name="patched_functions_for_meta_execution") +meta_patched_module = PatchRegistry(name="patched_modules_for_meta_execution") +bias_addition_function = PatchRegistry(name="patched_function_for_bias_addition") +bias_addition_module = PatchRegistry(name="patched_module_for_bias_addition") +bias_addition_method = PatchRegistry(name="patched_method_for_bias_addition") diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py index 28965a1b8e74..d9cb587b5d39 100644 --- a/colossalai/fx/tracer/tracer.py +++ b/colossalai/fx/tracer/tracer.py @@ -29,7 +29,7 @@ meta_patched_module, ) -__all__ = ['ColoTracer'] +__all__ = ["ColoTracer"] class TracerType(enum.Enum): @@ -103,7 +103,7 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr if kind == "call_function": if bias_addition_function.has(target): if target == torch.nn.functional.linear: - if 'bias' in kwargs and kwargs['bias'] is not None: + if "bias" in kwargs and kwargs["bias"] is not None: function_to_substitute = func_to_func_dict[target] handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) else: @@ -160,22 +160,27 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac if n not in parameter_proxy_cache: kwargs = {} if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: - kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else - lambda node: ParameterProxy(self, node, n, attr_val)) - val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ParameterProxy(self, node, n, attr_val) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] parameter_proxy_cache[n] = val_proxy return parameter_proxy_cache[n] return None if isinstance(attr_val, torch.nn.Parameter): - maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), - parameter_proxy_cache) + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) if maybe_parameter_proxy is not None: return maybe_parameter_proxy if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): - maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), - parameter_proxy_cache) + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) if maybe_buffer_proxy is not None: return maybe_buffer_proxy @@ -190,7 +195,7 @@ def call_module(self, m, forward, args, kwargs): # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched, # we should treat it as leaf module as well if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name): - return self.create_proxy('call_module', module_qualified_name, args, kwargs) + return self.create_proxy("call_module", module_qualified_name, args, kwargs) else: return forward(*args, **kwargs) @@ -211,7 +216,6 @@ def _configure_tracer_type(self, tracer_type: TracerType): raise ValueError(f"Unrecognized tracer type {tracer_type}") def _meta_data_computing(self, kind, target, args, kwargs): - if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta: meta_out = self.meta_args[target] return meta_out @@ -235,8 +239,9 @@ def _meta_data_computing(self, kind, target, args, kwargs): # Therefore, I need to record the nn.parameter.Parameter attribute for the operation # added by the bias addition manipulation following the get_attr node. convert_to_parameter = False - if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0], - torch.nn.parameter.Parameter): + if target in (torch.transpose, torch.reshape) and isinstance( + args_metas[0], torch.nn.parameter.Parameter + ): convert_to_parameter = True # fetch patched function if meta_patched_function.has(target): @@ -309,10 +314,12 @@ def _meta_data_computing(self, kind, target, args, kwargs): return meta_out - def trace(self, - root: nn.Module, - concrete_args: Optional[Dict[str, Tensor]] = None, - meta_args: Optional[Dict[str, Tensor]] = None) -> Graph: + def trace( + self, + root: nn.Module, + concrete_args: Optional[Dict[str, Tensor]] = None, + meta_args: Optional[Dict[str, Tensor]] = None, + ) -> Graph: """ Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow. @@ -341,9 +348,7 @@ def trace(self, # update concrete args with default values non_meta_arg_names = sig_names - meta_arg_names for k, v in sig.parameters.items(): - if k in non_meta_arg_names and \ - k not in concrete_args and \ - v.default is not inspect.Parameter.empty: + if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty: concrete_args[k] = v.default # get non concrete arg names @@ -354,7 +359,8 @@ def _check_arg_name_valid(names): success, element = is_element_in_list(names, sig_names) if not success: raise KeyError( - f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") + f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function" + ) _check_arg_name_valid(meta_arg_names) _check_arg_name_valid(concrete_arg_names) @@ -363,11 +369,13 @@ def _check_arg_name_valid(names): def _check_kwargs(kwargs, should_be_meta: bool): for k, v in kwargs.items(): if not should_be_meta: - assert not torch.is_tensor(v) or not v.is_meta, \ - f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer' + assert ( + not torch.is_tensor(v) or not v.is_meta + ), f"Expected the {k} not to be a meta tensor, please check the args passed to the tracer" else: - assert v.is_meta == should_be_meta, \ - f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer' + assert ( + v.is_meta == should_be_meta + ), f"Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer" _check_kwargs(concrete_args, should_be_meta=False) _check_kwargs(meta_args, should_be_meta=True) @@ -442,7 +450,6 @@ def trace_activation_checkpoint(self, enabled: bool): orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction class PatchedCheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): # signal that the current tracing occurs within activation checkpoint part @@ -455,7 +462,8 @@ def forward(ctx, run_function, preserve_rng_state, *args): @staticmethod def backward(ctx: Any, *grad_outputs: Any) -> Any: raise NotImplementedError( - "We do not implement the backward pass as we only trace the forward pass.") + "We do not implement the backward pass as we only trace the forward pass." + ) # override the checkpoint function torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction @@ -470,12 +478,11 @@ def create_node(self, *args, **kwargs) -> Node: if self.inside_torch_checkpoint_func: # annotate the activation checkpoint module - node.meta['activation_checkpoint'] = self.act_ckpt_region_count + node.meta["activation_checkpoint"] = self.act_ckpt_region_count return node def wrap_tensor_constructor_method(target): - def look_for_proxy(*args, **kwargs): # find in pos vars for arg in args: @@ -518,12 +525,10 @@ def wrapper(*args, **kwargs): for method in magic_methods: def _scope(method): - def impl(*args, **kwargs): - tracer = args[0].tracer target = getattr(operator, method) - proxy = tracer.create_proxy('call_function', target, args, kwargs) + proxy = tracer.create_proxy("call_function", target, args, kwargs) if not isinstance(proxy, ColoProxy): meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs) proxy = ColoProxy(proxy.node) @@ -542,7 +547,7 @@ def _define_reflectable(orig_method_name): def impl(self, rhs): target = getattr(operator, orig_method_name) - proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {}) + proxy = self.tracer.create_proxy("call_function", target, (rhs, self), {}) if not isinstance(proxy, ColoProxy): meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {}) proxy = ColoProxy(proxy.node) diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py index e467b4c73e6b..112b920ba158 100644 --- a/colossalai/inference/tensor_parallel/__init__.py +++ b/colossalai/inference/tensor_parallel/__init__.py @@ -1,4 +1,4 @@ from .engine import TPInferEngine from .kvcache_manager import MemoryManager -__all__ = ['MemoryManager', 'TPInferEngine'] +__all__ = ["MemoryManager", "TPInferEngine"] diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index 2bff9317283e..ac185f1b6529 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -1,6 +1,5 @@ # might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later from dataclasses import dataclass -from typing import Any import torch @@ -31,7 +30,7 @@ class BatchInferState: decode_mem_index: torch.Tensor = None decode_layer_id: int = None - device: torch.device = torch.device('cuda') + device: torch.device = torch.device("cuda") @property def total_token_num(self): @@ -43,13 +42,15 @@ def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager @staticmethod - def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, - alloc_mem_index: torch.Tensor): - """ in-place update block loc mapping based on the sequence length of the inputs in current bath""" + def init_block_loc( + b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor + ): + """in-place update block loc mapping based on the sequence length of the inputs in current bath""" start_index = 0 seq_len_numpy = seq_len.cpu().numpy() for i, cur_seq_len in enumerate(seq_len_numpy): - b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index + - cur_seq_len] + b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[ + start_index : start_index + cur_seq_len + ] start_index += cur_seq_len return diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index a5a55702ade0..1335f13d66b8 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch import torch.nn as nn @@ -15,7 +15,7 @@ DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 -_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM'] +_supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"] class TPInferEngine: @@ -39,14 +39,16 @@ class TPInferEngine: >>> outputs = infer_engine.generate(input_ids, **generate_kwargs) """ - def __init__(self, - model: nn.Module, - shard_config: ShardConfig, - max_batch_size: int, - max_input_len: int, - max_output_len: int, - dtype: torch.dtype = torch.float16, - device: str = 'cuda') -> None: + def __init__( + self, + model: nn.Module, + shard_config: ShardConfig, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + dtype: torch.dtype = torch.float16, + device: str = "cuda", + ) -> None: self.max_batch_size = max_batch_size self.max_input_len = max_input_len self.max_output_len = max_output_len @@ -63,7 +65,7 @@ def __init__(self, self.head_num = model.config.num_attention_heads self.layer_num = model.config.num_hidden_layers - self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config + self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None self.shard_config = shard_config @@ -74,9 +76,10 @@ def __init__(self, def _init_manager(self) -> None: assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig" assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" - self.head_num //= self.tp_size # update sharded number of heads - self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim, - self.layer_num) + self.head_num //= self.tp_size # update sharded number of heads + self.cache_manager = MemoryManager( + self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num + ) def _optimize_model(self, model: nn.Module) -> None: """ @@ -90,7 +93,7 @@ def _optimize_model(self, model: nn.Module) -> None: self._shard_model_by(shardformer, model) def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig: - """ Prepare the engine with a given ShardConfig. + """Prepare the engine with a given ShardConfig. Args: shard_config (ShardConfig): shard config given to specify settings of the engine. @@ -118,9 +121,10 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) return shard_config def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: - """ Shard original model by the given ShardFormer and store the sharded model. """ - assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \ - "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" + """Shard original model by the given ShardFormer and store the sharded model.""" + assert ( + self.tp_size == shardformer.shard_config.tensor_parallel_size + ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." policy = get_autopolicy(model, inference_only=True) @@ -147,7 +151,7 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], for t in input_tokens: if torch.is_tensor(input_tokens[t]): input_tokens[t] = input_tokens[t].cuda() - if 'max_new_tokens' not in generate_kwargs: + if "max_new_tokens" not in generate_kwargs: generate_kwargs.update(max_new_tokens=self.max_output_len) return self._generate_by_set_infer_state(input_tokens, **generate_kwargs) @@ -176,18 +180,18 @@ def prepare_batch_state(self, inputs) -> BatchInferState: attention_mask = None if isinstance(inputs, (BatchEncoding, dict)): - input_ids_list = inputs['input_ids'] - attention_mask = inputs['attention_mask'] + input_ids_list = inputs["input_ids"] + attention_mask = inputs["attention_mask"] else: input_ids_list = inputs - if isinstance(input_ids_list[0], int): # for a single input + if isinstance(input_ids_list[0], int): # for a single input input_ids_list = [input_ids_list] attention_mask = [attention_mask] if attention_mask is not None else attention_mask batch_size = len(input_ids_list) - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda') - seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda') + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") start_index = 0 max_len_in_batch = -1 @@ -210,10 +214,10 @@ def prepare_batch_state(self, inputs) -> BatchInferState: seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda') + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") batch_infer_state = BatchInferState(batch_size, max_len_in_batch) - batch_infer_state.seq_len = seq_lengths.to('cuda') - batch_infer_state.start_loc = seq_start_indexes.to('cuda') + batch_infer_state.seq_len = seq_lengths.to("cuda") + batch_infer_state.start_loc = seq_start_indexes.to("cuda") batch_infer_state.block_loc = block_loc batch_infer_state.decode_layer_id = 0 batch_infer_state.past_key_values_len = 0 @@ -248,7 +252,7 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch model = self.model.model elif isinstance(model, BloomForCausalLM): model = self.model.transformer - setattr(model, 'infer_state', batch_infer_state) + setattr(model, "infer_state", batch_infer_state) outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False) @@ -262,14 +266,15 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch # as an arg into model.forward. # It requires rewriting model generate and replacing model forward. @torch.no_grad() - def _generate_by_pass_infer_state(self, - input_tokens, - max_out_length: int, - generation_config: Optional[GenerationConfig] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: - + def _generate_by_pass_infer_state( + self, + input_tokens, + max_out_length: int, + generation_config: Optional[GenerationConfig] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + **model_kwargs, + ) -> torch.Tensor: raise NotImplementedError("generate by passing BatchInferState is not implemented.") # might want to use in rewritten generate method: use after model.forward diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index 274c01841279..e74a3a491a7b 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -19,13 +19,15 @@ class MemoryManager: device: device used to store the key and value cache """ - def __init__(self, - size: int, - dtype: torch.dtype, - head_num: int, - head_dim: int, - layer_num: int, - device: torch.device = torch.device('cuda')): + def __init__( + self, + size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + layer_num: int, + device: torch.device = torch.device("cuda"), + ): self.logger = logging.get_logger(__name__) self.available_size = size self.past_key_values_length = 0 @@ -33,13 +35,13 @@ def __init__(self, self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) def _init_mem_states(self, size, device): - """ Initialize tensors used to manage memory states """ + """Initialize tensors used to manage memory states""" self.mem_state = torch.ones((size,), dtype=torch.bool, device=device) self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device) self.indexes = torch.arange(0, size, dtype=torch.long, device=device) def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): - """ Initialize key buffer and value buffer on specified device """ + """Initialize key buffer and value buffer on specified device""" self.key_buffer = [ torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num) ] @@ -49,10 +51,9 @@ def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num): @torch.no_grad() def alloc(self, required_size): - """ allocate space of required_size by providing indexes representing available physical spaces """ + """allocate space of required_size by providing indexes representing available physical spaces""" if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " - f"left_size {self.available_size}") + self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") return None torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1) @@ -63,23 +64,25 @@ def alloc(self, required_size): @torch.no_grad() def alloc_contiguous(self, required_size): - """ allocate contiguous space of required_size """ + """allocate contiguous space of required_size""" if required_size > self.available_size: - self.logger.warning(f"No enough cache: required_size {required_size} " - f"left_size {self.available_size}") + self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}") return None torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum) sum_size = len(self.mem_cum_sum) - loc_sums = self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size + - 1] + self.mem_state[0:sum_size - - required_size + 1] - can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size] + loc_sums = ( + self.mem_cum_sum[required_size - 1 :] + - self.mem_cum_sum[0 : sum_size - required_size + 1] + + self.mem_state[0 : sum_size - required_size + 1] + ) + can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size] if can_used_loc.shape[0] == 0: - self.logger.info(f"No enough contiguous cache: required_size {required_size} " - f"left_size {self.available_size}") + self.logger.info( + f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}" + ) return None start_loc = can_used_loc[0] - select_index = self.indexes[start_loc:start_loc + required_size] + select_index = self.indexes[start_loc : start_loc + required_size] self.mem_state[select_index] = 0 self.available_size -= len(select_index) start = start_loc.item() @@ -88,13 +91,13 @@ def alloc_contiguous(self, required_size): @torch.no_grad() def free(self, free_index): - """ free memory by updating memory states based on given indexes """ + """free memory by updating memory states based on given indexes""" self.available_size += free_index.shape[0] self.mem_state[free_index] = 1 @torch.no_grad() def free_all(self): - """ free all memory by updating memory states """ + """free all memory by updating memory states""" self.available_size = len(self.mem_state) self.mem_state[:] = 1 self.past_key_values_length = 0 diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py index 7a98b033f37e..27cec5452ece 100644 --- a/colossalai/inference/tensor_parallel/modeling/__init__.py +++ b/colossalai/inference/tensor_parallel/modeling/__init__.py @@ -1,4 +1,4 @@ from .bloom import BloomInferenceForwards from .llama import LlamaInferenceForwards -__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards'] +__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"] diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index ba5eadc92be8..27a26caabefa 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -1,6 +1,6 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.distributed as dist @@ -31,17 +31,17 @@ def generate_alibi(n_head, dtype=torch.float16): """ def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) + start = 2 ** (-(2 ** -(math.log2(n) - 3))) return [start * start**i for i in range(n)] def get_slopes(n): if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2**math.floor(math.log2(n)) + closest_power_of_2 = 2 ** math.floor(math.log2(n)) slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2) slopes_double = get_slopes(2 * closest_power_of_2) - slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2] + slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2] return slopes_combined slopes = get_slopes(n_head) @@ -72,7 +72,6 @@ def bloom_model_forward( infer_state: Optional[BatchInferState] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - logger = logging.get_logger(__name__) if deprecated_arguments.pop("position_ids", False) is not False: @@ -86,8 +85,9 @@ def bloom_model_forward( raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -122,14 +122,15 @@ def bloom_model_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False # NOTE determine if BatchInferState is passed in via arg # if not, get the attr binded to the model # We might wantto remove setattr later if infer_state is None: - assert hasattr(self, 'infer_state') + assert hasattr(self, "infer_state") infer_state = self.infer_state # Compute alibi tensor: check build_alibi_tensor documentation @@ -146,10 +147,11 @@ def bloom_model_forward( if use_cache and seq_length != 1: # prefill stage - infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, - infer_state.context_mem_index) + BatchInferState.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) else: infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) @@ -182,8 +184,11 @@ def bloom_model_forward( # alibi = generate_alibi(self.num_heads).contiguous().cuda() tp_size = dist.get_world_size() curr_tp_rank = dist.get_rank() - alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) * - self.num_heads].cuda() + alibi = ( + generate_alibi(self.num_heads * tp_size) + .contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads] + .cuda() + ) causal_mask = self._prepare_attn_mask( attention_mask, input_shape=(batch_size, seq_length), @@ -197,7 +202,6 @@ def bloom_model_forward( if self.gradient_checkpointing and self.training: # NOTE: currently our KV cache manager does not handle this condition def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) @@ -250,32 +254,34 @@ def custom_forward(*inputs): return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, # should always be (None, None, ..., None) + past_key_values=presents, # should always be (None, None, ..., None) hidden_states=all_hidden_states, attentions=all_self_attentions, ) @staticmethod - def bloom_for_causal_lm_forward(self: BloomForCausalLM, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - infer_state: Optional[BatchInferState] = None, - **deprecated_arguments): + def bloom_for_causal_lm_forward( + self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: Optional[BatchInferState] = None, + **deprecated_arguments, + ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ - logger = logging.get_logger(__name__) + logging.get_logger(__name__) if deprecated_arguments.pop("position_ids", False) is not False: # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` @@ -289,17 +295,19 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - infer_state=infer_state) + transformer_outputs = BloomInferenceForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + ) hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) @@ -314,8 +322,9 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), - shift_labels.view(batch_size * seq_length)) + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -353,11 +362,13 @@ def bloom_for_causal_lm_prepare_inputs_for_generation( else: model_inputs = {"input_ids": input_ids} - model_inputs.update({ - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - }) + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) return model_inputs @staticmethod @@ -416,7 +427,7 @@ def bloom_block_forward( else: outputs = (output,) + outputs[1:] - return outputs # hidden_states, present, attentions + return outputs # hidden_states, present, attentions @staticmethod def bloom_attention_forward( @@ -431,20 +442,19 @@ def bloom_attention_forward( output_attentions: bool = False, infer_state: Optional[BatchInferState] = None, ): - - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) batch_size, q_length, H, D_HEAD = query_layer.shape - k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 - v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 mem_manager = infer_state.cache_manager layer_id = infer_state.decode_layer_id - if layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_length # += 1 + if layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_length # += 1 if infer_state.is_context_stage: # context process @@ -471,9 +481,11 @@ def bloom_attention_forward( if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly cache_k = infer_state.cache_manager.key_buffer[layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] cache_v = infer_state.cache_manager.value_buffer[layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] cache_k.copy_(k) cache_v.copy_(v) else: @@ -486,8 +498,17 @@ def bloom_attention_forward( b_loc = infer_state.block_loc b_seq_len = infer_state.seq_len output = torch.empty_like(q) - token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc, - b_start_loc, b_seq_len, infer_state.cache_manager.past_key_values_length, alibi) + token_attention_fwd( + q, + mem_manager.key_buffer[layer_id], + mem_manager.value_buffer[layer_id], + output, + b_loc, + b_start_loc, + b_seq_len, + infer_state.cache_manager.past_key_values_length, + alibi, + ) context_layer = output.view(batch_size, q_length, H * D_HEAD) @@ -504,8 +525,8 @@ def bloom_attention_forward( output_tensor = torch.zeros_like(context_layer) for i in range(self.pretraining_tp): output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices):int((i + 1) * slices)], - self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: output_tensor = self.dense(context_layer) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 07b73a6f4ca6..4795162f1980 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -1,6 +1,5 @@ from typing import List, Optional, Tuple -import numpy as np import torch from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm @@ -15,6 +14,7 @@ try: from vllm import layernorm_ops, pos_encoding_ops + rms_norm = layernorm_ops.rms_norm rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox HAS_VLLM_KERNERL = True @@ -29,17 +29,17 @@ def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -71,8 +71,7 @@ def llama_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - - batch_size = input_ids.shape[0] # input_ids.shape[0] + batch_size = input_ids.shape[0] # input_ids.shape[0] infer_state = self.infer_state @@ -103,10 +102,11 @@ def llama_model_forward( if use_cache and seq_length != 1: # NOTE assuem prefill stage # allocate memory block - infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, - infer_state.context_mem_index) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) else: infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) @@ -129,20 +129,20 @@ def llama_model_forward( infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() if infer_state.is_context_stage: - infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1) + position_ids.view(-1).shape[0], -1 + ) infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( - position_ids.view(-1).shape[0], -1) + position_ids.view(-1).shape[0], -1 + ) else: seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) @@ -153,12 +153,13 @@ def llama_model_forward( # embed positions if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, - past_key_values_length) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) hidden_states = inputs_embeds @@ -216,7 +217,6 @@ def llama_decoder_layer_forward( use_cache: Optional[bool] = False, infer_state: Optional[BatchInferState] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -261,7 +261,6 @@ def llama_flash_attn_kvcache_forward( use_cache: bool = False, infer_state: Optional[BatchInferState] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - assert use_cache is True, "use_cache should be set to True using this llama attention" bsz, q_len, _ = hidden_states.size() @@ -277,8 +276,8 @@ def llama_flash_attn_kvcache_forward( # NOTE might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) @@ -299,38 +298,62 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, # first token generation # copy key and value calculated in current step to memory manager - _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, infer_state.context_mem_index, - infer_state.cache_manager) + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) attn_output = torch.empty_like(query_states) - llama_context_attn_fwd(query_states, key_states, value_states, attn_output, infer_state.start_loc, - infer_state.seq_len, infer_state.cache_manager.past_key_values_length) + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) else: - if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] cache_k.copy_(key_states) cache_v.copy_(value_states) else: # if decode is not contiguous, use triton kernel to copy key and value cache # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head - _copy_kv_to_mem_cache(infer_state.decode_layer_id, key_states, value_states, - infer_state.decode_mem_index, infer_state.cache_manager) + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) # second token and follows # kv = torch.stack((key_states, value_states), dim=2) # (batch_size, seqlen, nheads, headdim) attn_output = torch.empty_like(query_states) - token_attention_fwd(query_states, infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], attn_output, - infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length) + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -341,7 +364,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, def get_llama_vllm_rmsnorm_forward(): - if HAS_VLLM_KERNERL: def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py index 48f8db62c32a..fcb1b6a3bd8f 100644 --- a/colossalai/inference/tensor_parallel/policies/__init__.py +++ b/colossalai/inference/tensor_parallel/policies/__init__.py @@ -1,4 +1,4 @@ from .bloom import BloomModelInferPolicy from .llama import LlamaModelInferPolicy -__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy'] +__all__ = ["BloomModelInferPolicy", "LlamaModelInferPolicy"] diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index cae43aa20421..2d18a3922c1e 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -9,6 +9,7 @@ try: from colossalai.kernel.triton import layer_norm + HAS_TRITON_NORM = True except: print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton") @@ -27,40 +28,40 @@ def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor): class BloomModelInferPolicy(BloomForCausalLMPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel + policy = super().module_policy() # NOTE set inference mode to shard config self.shard_config._infer() method_replacement = { - 'forward': BloomInferenceForwards.bloom_for_causal_lm_forward, - 'prepare_inputs_for_generation': BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation + "forward": BloomInferenceForwards.bloom_for_causal_lm_forward, + "prepare_inputs_for_generation": BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation, } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BloomForCausalLM) + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=BloomForCausalLM + ) - method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward} + method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel) - method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward} + method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock) - method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BloomAttention) + method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=BloomAttention + ) if HAS_TRITON_NORM: infer_method = get_triton_layernorm_forward() - method_replacement = {'forward': partial(infer_method)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LayerNorm) + method_replacement = {"forward": partial(infer_method)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LayerNorm + ) return policy diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 4844415d612c..9bbb547dbcae 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -10,6 +10,7 @@ try: from colossalai.kernel.triton import rmsnorm_forward + HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") @@ -28,7 +29,6 @@ def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): class LlamaModelInferPolicy(LlamaForCausalLMPolicy): - def __init__(self) -> None: super().__init__() @@ -37,20 +37,20 @@ def module_policy(self): self.shard_config._infer() infer_forward = LlamaInferenceForwards.llama_model_forward - method_replacement = {'forward': partial(infer_forward)} + method_replacement = {"forward": partial(infer_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaDecoderLayer) + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + ) infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaAttention) + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaAttention + ) infer_forward = None if HAS_TRITON_RMSNORM: @@ -60,9 +60,9 @@ def module_policy(self): infer_forward = get_llama_vllm_rmsnorm_forward() if infer_forward is not None: - method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=LlamaRMSNorm) + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaRMSNorm + ) return policy diff --git a/colossalai/initialize.py b/colossalai/initialize.py index b8718abc80bd..aac57d34a2c1 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -14,15 +14,17 @@ from colossalai.utils import set_device, set_seed -def launch(config: Union[str, Path, Config, Dict], - rank: int, - world_size: int, - host: str, - port: int, - backend: str = 'nccl', - local_rank: int = None, - seed: int = 1024, - verbose: bool = True): +def launch( + config: Union[str, Path, Config, Dict], + rank: int, + world_size: int, + host: str, + port: int, + backend: str = "nccl", + local_rank: int = None, + seed: int = 1024, + verbose: bool = True, +): """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input arguments are not given. Then initialize and set distributed environment by calling global_context's functions. @@ -46,7 +48,7 @@ def launch(config: Union[str, Path, Config, Dict], warnings.warn("`config` is deprecated and will be removed soon.") # init default process group - init_method = f'tcp://[{host}]:{port}' + init_method = f"tcp://[{host}]:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # set cuda device @@ -58,15 +60,17 @@ def launch(config: Union[str, Path, Config, Dict], if verbose: logger = get_dist_logger() - logger.info(f'Distributed environment is initialized, world size: {dist.get_world_size()}', ranks=[0]) + logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0]) -def launch_from_slurm(config: Union[str, Path, Config, Dict], - host: str, - port: int, - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): +def launch_from_slurm( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables set by SLURM @@ -79,29 +83,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['SLURM_PROCID']) - world_size = int(os.environ['SLURM_NPROCS']) + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NPROCS"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" ) - launch(config=config, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def launch_from_openmpi(config: Union[str, Path, Config, Dict], - host: str, - port: int, - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + launch( + config=config, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_openmpi( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables set by OpenMPI @@ -114,29 +122,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['OMPI_COMM_WORLD_RANK']) - local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) - world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" ) - launch(config=config, - local_rank=local_rank, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def launch_from_torch(config: Union[str, Path, Config, Dict], - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_torch( + config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True +): """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size from the environment variables set by PyTorch @@ -147,22 +156,24 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - host = os.environ['MASTER_ADDR'] - port = int(os.environ['MASTER_PORT']) + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + port = int(os.environ["MASTER_PORT"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" ) - launch(config=config, - local_rank=local_rank, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py index 1c3199fc1aff..98b21c9c02c1 100644 --- a/colossalai/interface/__init__.py +++ b/colossalai/interface/__init__.py @@ -1,4 +1,4 @@ from .model import AMPModelMixin, ModelWrapper from .optimizer import OptimizerWrapper -__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin'] +__all__ = ["OptimizerWrapper", "ModelWrapper", "AMPModelMixin"] diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py index 7b3d9435d255..58df09b853ee 100644 --- a/colossalai/interface/model.py +++ b/colossalai/interface/model.py @@ -26,11 +26,9 @@ def forward(self, *args, **kwargs): class AMPModelMixin: - """This mixin class defines the interface for AMP training. - """ + """This mixin class defines the interface for AMP training.""" def update_master_params(self): """ Update the master parameters for AMP training. """ - pass diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index bc270b1d9c89..95d11087bece 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -22,7 +22,7 @@ def parameters(self): params = [] for group in self.param_groups: - params += group['params'] + params += group["params"] return params @property @@ -82,12 +82,14 @@ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: """ nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs) - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> Tensor: + def clip_grad_by_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = False, + *args, + **kwargs, + ) -> Tensor: """ Clips gradient norm of an iterable of parameters. @@ -113,7 +115,8 @@ def scale_loss(self, loss: Tensor): loss (Tensor): The loss to be scaled. """ raise NotImplementedError( - "The method scale_loss is only available for optimizers with mixed precision training") + "The method scale_loss is only available for optimizers with mixed precision training" + ) def unscale_grad(self): """ @@ -122,7 +125,8 @@ def unscale_grad(self): Note: Only available for optimizers with mixed precision training. """ raise NotImplementedError( - "The method unscale_grad is only available for optimizers with mixed precision training") + "The method unscale_grad is only available for optimizers with mixed precision training" + ) def unwrap(self): """ diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py index e0136d86e561..f8a974b5fb26 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/colossalai/kernel/cuda_native/__init__.py @@ -4,6 +4,10 @@ from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax __all__ = [ - 'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention', - 'AttnMaskType' + "LayerNorm", + "MultiHeadAttention", + "FusedScaleMaskSoftmax", + "ScaledUpperTriangMaskedSoftmax", + "ColoAttention", + "AttnMaskType", ] diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h index 00066dc95475..a62beef91a8a 100644 --- a/colossalai/kernel/cuda_native/csrc/compat.h +++ b/colossalai/kernel/cuda_native/csrc/compat.h @@ -7,4 +7,4 @@ #define DATA_PTR data_ptr #else #define DATA_PTR data -#endif \ No newline at end of file +#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu index 26efa2ad6f31..9a6a8ebc3983 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu @@ -1,7 +1,6 @@ #include #include - #include "cuda_util.h" /* GPU function guard */ diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu index a39a6dae0f7f..ce0b017f12e1 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu @@ -1,1002 +1,1002 @@ -#include -#include - -#include "kernels.h" - -#include - - -namespace cg = cooperative_groups; - -curandStatePhilox4_32_10_t *curandstate; - -/** - * @brief element-wise activation function on device, like Relu, Gelu - * - * @tparam enum class ActivationType, kRelu, kGelu - * @tparam input type - * @param any shape of float and __half2 - * @return same shape and type with input - */ -template -__forceinline__ __device__ T activation_kernel(T x); - -template <> -__device__ float activation_kernel(float x) { - float cdf = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template <> -__device__ __half2 -activation_kernel(__half2 val) { - __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); - - tmp.x = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); - return __hmul2(val, __float22half2_rn(tmp)); -} - -template <> -__device__ float activation_kernel(float x) { - return fmaxf(x, 0); -} - -template <> -__device__ __half2 -activation_kernel(__half2 x) { - return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), - fmaxf(0.f, __half2float(x.y))); -} - -/** - * @brief element-wise activation backward function on device - * - * @tparam enum class ActivationType - * @tparam input type - * @param any shape of float and __half2 - * @return same shape of input - */ -template -__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * (dg1 + dg2 + dg3); -} - -template <> -__device__ __half activation_bwd_kernel( - __half grad, __half x_half) { - float x = __half2float(x_half); - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * __float2half(dg1 + dg2 + dg3); -} - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - return x > 0.f ? grad : 0.f; -} - -template <> -__device__ __half -activation_bwd_kernel(__half grad, __half x) { - const __half half_zero = __float2half(0.f); - return x > half_zero ? grad : half_zero; -} - -template <> -__device__ __half2 activation_bwd_kernel( - __half2 grad2, __half2 x_half2) { - const __half half_zero = __float2half(0.f); - return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, - x_half2.y > half_zero ? grad2.y : half_zero); -} - -/** - * @brief init curand states in global memory - * - * @thread grid_dim * block*dim to suuport any size of states - * @param state persistant curand states - * @param seed seed to init states - * @return void - */ -__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, - int seed) { - /* Each thread gets same seed, a different sequence - number, no offset */ - int id = threadIdx.x + blockIdx.x * blockDim.x; - curand_init(seed, id, 0, &state[id]); -} - -void launch_curand_init(int total_count, int dim, cudaStream_t stream) { - cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); - int grid_dim = total_count >> 9; - curand_init_kernel<<>>( - curandstate, std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); -} - -/** - * @brief element-wise dropout, store dropped position in mask, it's not - * in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out any size of float and __half - * @param in same with out - * @param mask uint8 type, same size with out - * @param seed seed to curand - * @return void - */ -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - float *__restrict__ out, - const float *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - - float4 input4 = data4[i]; - float4 res4; - res4.x = input4.x * scale * m[0]; - res4.y = input4.y * scale * m[1]; - res4.z = input4.z * scale * m[2]; - res4.w = input4.w * scale * m[3]; - out4[i] = res4; -} - -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - __half *__restrict__ out, - const __half *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - outs_float4[i] = out_float4; -} - -/** - * @brief element-wise dropout backward with dropout mask, it's - * not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param in any size of float and __half - * @param mask uint8 type, same size with in - * @return void - */ -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - float *out, const float *in, - const uint8_t *__restrict__ mask) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *in4 = reinterpret_cast(in); - const uint32_t *mask4 = reinterpret_cast(mask); - - uint32_t *m4 = reinterpret_cast(m); - m4[0] = mask4[i]; - - float4 input4 = in4[i]; - float4 res4; - res4.x = input4.x * scale * static_cast(m[0]); - res4.y = input4.y * scale * static_cast(m[1]); - res4.z = input4.z * scale * static_cast(m[2]); - res4.w = input4.w * scale * static_cast(m[3]); - out4[i] = res4; -} - -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - __half *out, const __half *in, - const uint8_t *__restrict__ mask) { - const __half scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - float4 *out4 = reinterpret_cast(out); - const float4 *vals_float4 = reinterpret_cast(in); - const uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - uint64_t *m8 = reinterpret_cast(m); - m8[0] = mask8[i]; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - out4[i] = out_float4; -} - -template <> -void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, - int total_count, float ratio, cudaStream_t stream, - bool backward) { - int grid_dim = total_count >> 12; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -template <> -void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, - int total_count, float ratio, - cudaStream_t stream, bool backward) { - int grid_dim = total_count >> 13; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -/** - * @brief fused bias, dropout, and residual at the end of Attention and FFN, - * store dropped position in mask, it's not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param residual [batch_size, seq_len, hidden_size], float and __half - * @param seed seed to curand - * @param hidden_size hidden size - * @return void - */ -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const float *__restrict__ residual, - const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 output4; - - output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; - output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; - output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; - output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; - - out4[i] = output4; -} - -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const __half *__restrict__ residual, - const int seed, const int hidden_size) { - const __half scale = 1. / (1. - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = static_cast(rand.x > ratio); - m[5] = static_cast(rand.y > ratio); - m[6] = static_cast(rand.z > ratio); - m[7] = static_cast(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = m8[0]; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - const __half2 *res_half2 = reinterpret_cast(&res4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = - __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); - out_half2[1] = - __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); - out_half2[2] = - __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); - out_half2[3] = - __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_res_bias(float *out, const float *vals, - uint8_t *mask, const float *bias, - const float *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 12; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, - uint8_t *mask, const __half *bias, - const __half *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 13; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias and dropout backward at the end of Attention and FFN - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, float *__restrict__ in_grad, - float *__restrict__ bias_grad, const float *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - // every block generate 8 bias result - __shared__ float tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - float val = out_grad[idx]; - val *= scale * static_cast(mask[idx]); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - float sum = 0; - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, __half *__restrict__ in_grad, - __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); - __shared__ __half2 tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); - const __half2 *out_grad2 = reinterpret_cast(out_grad); - __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - __half2 local_sum = __float2half2_rn(0.f); - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - __half2 val = out_grad2[idx]; - __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); - val *= scale * m2; - local_sum += val; - in_grad2[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - __half2 sum = __float2half2_rn(0.f); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad2[pos] = tile[0][threadIdx.x]; - } -} - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template <> -void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, - const __half *out_grad, const uint8_t *mask, - int row_size, int dim, float ratio, - cudaStream_t stream) { - dim >>= 1; - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, - const float *out_grad, - const uint8_t *mask, int row_size, - int dim, float ratio, - cudaStream_t stream); - -/** - * @brief fused bias, activation, and dropout at the end of first ffn - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @tparam act_type activation function, like kRelu, kGelu - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param seed seed to curand - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 output4; - - output4.x = - activation_kernel(input4.x + b4.x) * scale * m[0]; - output4.y = - activation_kernel(input4.y + b4.y) * scale * m[1]; - output4.z = - activation_kernel(input4.z + b4.z) * scale * m[2]; - output4.w = - activation_kernel(input4.w + b4.w) * scale * m[3]; - - out4[i] = output4; -} - -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2( - activation_kernel(__hadd2(val_half2[0], b_half2[0])), - scale_mask_1); - out_half2[1] = __hmul2( - activation_kernel(__hadd2(val_half2[1], b_half2[1])), - scale_mask_2); - out_half2[2] = __hmul2( - activation_kernel(__hadd2(val_half2[2], b_half2[2])), - scale_mask_3); - out_half2[3] = __hmul2( - activation_kernel(__hadd2(val_half2[3], b_half2[3])), - scale_mask_4); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias, activation, and dropout backward - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @tparam act_type kRelu - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_bwd_kernel( - const int row_size, const float ratio, T *in_grad, - T *__restrict__ bias_grad, const T *__restrict__ input, - const T *__restrict__ bias, const T *out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - - int stride = hidden_size * WARP_SIZE; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - if (col_idx < hidden_size) { - for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { - float val = out_grad[idx]; - float in = input[idx]; - float b = bias[idx % hidden_size]; - val = activation_bwd_kernel( - val * scale * static_cast(mask[idx]), in + b); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - float sum = tile[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; - __syncthreads(); - - if (threadIdx.y == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -// @brief fused bias, activation, and dropout backward -// It is deprecated for precision reason. Keep it for future optimization. -// -// template -// __global__ void ls_dropout_act_bias_bwd_kernel( -// const int row_size, const float ratio, __half * in_grad, -// __half *__restrict__ bias_grad, const __half *__restrict__ input, const -// __half *__restrict__ bias, const __half * out_grad, const uint8_t -// *__restrict__ mask, const int hidden_size) { -// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); -// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; - -// cg::thread_block b = cg::this_thread_block(); -// cg::thread_block_tile g = cg::tiled_partition(b); - -// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); -// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); -// const __half2 *out_grad2 = reinterpret_cast(out_grad); -// const __half2 *input2 = reinterpret_cast(input); -// const __half2 *bias2 = reinterpret_cast(bias); - -// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - -// int stride = hidden_size * WARP_SIZE; -// __half2 local_sum = __float2half2_rn(0.f); - -// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); -// if (col_idx < hidden_size) { -// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { -// __half2 val = out_grad2[idx]; -// __half2 in2 = input2[idx]; -// __half2 b2 = bias2[idx % hidden_size ]; -// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); -// val = activation_bwd_kernel(val * scale -// * -// m2, -// in2+b2); -// local_sum += val; -// in_grad2[idx] = val; -// idx += stride; -// } -// } - -// tile[threadIdx.x][threadIdx.y] = local_sum; -// __syncthreads(); -// __half2 sum = tile[threadIdx.y][threadIdx.x]; -// __syncthreads(); - -// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - -// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; -// __syncthreads(); - -// if (threadIdx.y == 0) { -// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); -// bias_grad2[pos] = tile[0][threadIdx.x]; -// } -// } - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - ls_dropout_act_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); -} - -// template <> -// void launch_ls_dropout_act_bias_bwd( -// __half *in_grad, __half *bias_grad,const __half *input, const __half -// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int -// dim, float ratio, cudaStream_t stream) { -// dim >>= 1; -// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); -// dim3 block_dim(WARP_SIZE, WARP_SIZE); -// ls_dropout_act_bias_bwd_kernel -// <<>>(row_size, ratio, in_grad, -// bias_grad, -// input, bias,out_grad, mask, dim); -// } - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); +#include +#include + +#include "kernels.h" + +#include + + +namespace cg = cooperative_groups; + +curandStatePhilox4_32_10_t *curandstate; + +/** + * @brief element-wise activation function on device, like Relu, Gelu + * + * @tparam enum class ActivationType, kRelu, kGelu + * @tparam input type + * @param any shape of float and __half2 + * @return same shape and type with input + */ +template +__forceinline__ __device__ T activation_kernel(T x); + +template <> +__device__ float activation_kernel(float x) { + float cdf = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template <> +__device__ __half2 +activation_kernel(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +template <> +__device__ float activation_kernel(float x) { + return fmaxf(x, 0); +} + +template <> +__device__ __half2 +activation_kernel(__half2 x) { + return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), + fmaxf(0.f, __half2float(x.y))); +} + +/** + * @brief element-wise activation backward function on device + * + * @tparam enum class ActivationType + * @tparam input type + * @param any shape of float and __half2 + * @return same shape of input + */ +template +__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * (dg1 + dg2 + dg3); +} + +template <> +__device__ __half activation_bwd_kernel( + __half grad, __half x_half) { + float x = __half2float(x_half); + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * __float2half(dg1 + dg2 + dg3); +} + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + return x > 0.f ? grad : 0.f; +} + +template <> +__device__ __half +activation_bwd_kernel(__half grad, __half x) { + const __half half_zero = __float2half(0.f); + return x > half_zero ? grad : half_zero; +} + +template <> +__device__ __half2 activation_bwd_kernel( + __half2 grad2, __half2 x_half2) { + const __half half_zero = __float2half(0.f); + return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, + x_half2.y > half_zero ? grad2.y : half_zero); +} + +/** + * @brief init curand states in global memory + * + * @thread grid_dim * block*dim to suuport any size of states + * @param state persistant curand states + * @param seed seed to init states + * @return void + */ +__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, + int seed) { + /* Each thread gets same seed, a different sequence + number, no offset */ + int id = threadIdx.x + blockIdx.x * blockDim.x; + curand_init(seed, id, 0, &state[id]); +} + +void launch_curand_init(int total_count, int dim, cudaStream_t stream) { + cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); + int grid_dim = total_count >> 9; + curand_init_kernel<<>>( + curandstate, std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); +} + +/** + * @brief element-wise dropout, store dropped position in mask, it's not + * in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out any size of float and __half + * @param in same with out + * @param mask uint8 type, same size with out + * @param seed seed to curand + * @return void + */ +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + float *__restrict__ out, + const float *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + + float4 input4 = data4[i]; + float4 res4; + res4.x = input4.x * scale * m[0]; + res4.y = input4.y * scale * m[1]; + res4.z = input4.z * scale * m[2]; + res4.w = input4.w * scale * m[3]; + out4[i] = res4; +} + +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + __half *__restrict__ out, + const __half *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + outs_float4[i] = out_float4; +} + +/** + * @brief element-wise dropout backward with dropout mask, it's + * not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param in any size of float and __half + * @param mask uint8 type, same size with in + * @return void + */ +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + float *out, const float *in, + const uint8_t *__restrict__ mask) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *in4 = reinterpret_cast(in); + const uint32_t *mask4 = reinterpret_cast(mask); + + uint32_t *m4 = reinterpret_cast(m); + m4[0] = mask4[i]; + + float4 input4 = in4[i]; + float4 res4; + res4.x = input4.x * scale * static_cast(m[0]); + res4.y = input4.y * scale * static_cast(m[1]); + res4.z = input4.z * scale * static_cast(m[2]); + res4.w = input4.w * scale * static_cast(m[3]); + out4[i] = res4; +} + +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + __half *out, const __half *in, + const uint8_t *__restrict__ mask) { + const __half scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + float4 *out4 = reinterpret_cast(out); + const float4 *vals_float4 = reinterpret_cast(in); + const uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + uint64_t *m8 = reinterpret_cast(m); + m8[0] = mask8[i]; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + out4[i] = out_float4; +} + +template <> +void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, + int total_count, float ratio, cudaStream_t stream, + bool backward) { + int grid_dim = total_count >> 12; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +template <> +void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, + int total_count, float ratio, + cudaStream_t stream, bool backward) { + int grid_dim = total_count >> 13; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +/** + * @brief fused bias, dropout, and residual at the end of Attention and FFN, + * store dropped position in mask, it's not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param residual [batch_size, seq_len, hidden_size], float and __half + * @param seed seed to curand + * @param hidden_size hidden size + * @return void + */ +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const float *__restrict__ residual, + const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 output4; + + output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; + output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; + output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; + output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; + + out4[i] = output4; +} + +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const __half *__restrict__ residual, + const int seed, const int hidden_size) { + const __half scale = 1. / (1. - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = static_cast(rand.x > ratio); + m[5] = static_cast(rand.y > ratio); + m[6] = static_cast(rand.z > ratio); + m[7] = static_cast(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = m8[0]; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + const __half2 *res_half2 = reinterpret_cast(&res4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = + __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); + out_half2[1] = + __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); + out_half2[2] = + __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); + out_half2[3] = + __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_res_bias(float *out, const float *vals, + uint8_t *mask, const float *bias, + const float *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 12; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, + uint8_t *mask, const __half *bias, + const __half *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 13; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias and dropout backward at the end of Attention and FFN + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, float *__restrict__ in_grad, + float *__restrict__ bias_grad, const float *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + // every block generate 8 bias result + __shared__ float tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + float val = out_grad[idx]; + val *= scale * static_cast(mask[idx]); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + float sum = 0; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, __half *__restrict__ in_grad, + __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); + __shared__ __half2 tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); + const __half2 *out_grad2 = reinterpret_cast(out_grad); + __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + __half2 local_sum = __float2half2_rn(0.f); + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + __half2 val = out_grad2[idx]; + __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); + val *= scale * m2; + local_sum += val; + in_grad2[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + __half2 sum = __float2half2_rn(0.f); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad2[pos] = tile[0][threadIdx.x]; + } +} + +template +void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template <> +void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, + const __half *out_grad, const uint8_t *mask, + int row_size, int dim, float ratio, + cudaStream_t stream) { + dim >>= 1; + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, + const float *out_grad, + const uint8_t *mask, int row_size, + int dim, float ratio, + cudaStream_t stream); + +/** + * @brief fused bias, activation, and dropout at the end of first ffn + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @tparam act_type activation function, like kRelu, kGelu + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param seed seed to curand + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 output4; + + output4.x = + activation_kernel(input4.x + b4.x) * scale * m[0]; + output4.y = + activation_kernel(input4.y + b4.y) * scale * m[1]; + output4.z = + activation_kernel(input4.z + b4.z) * scale * m[2]; + output4.w = + activation_kernel(input4.w + b4.w) * scale * m[3]; + + out4[i] = output4; +} + +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2( + activation_kernel(__hadd2(val_half2[0], b_half2[0])), + scale_mask_1); + out_half2[1] = __hmul2( + activation_kernel(__hadd2(val_half2[1], b_half2[1])), + scale_mask_2); + out_half2[2] = __hmul2( + activation_kernel(__hadd2(val_half2[2], b_half2[2])), + scale_mask_3); + out_half2[3] = __hmul2( + activation_kernel(__hadd2(val_half2[3], b_half2[3])), + scale_mask_4); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias, activation, and dropout backward + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @tparam act_type kRelu + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_bwd_kernel( + const int row_size, const float ratio, T *in_grad, + T *__restrict__ bias_grad, const T *__restrict__ input, + const T *__restrict__ bias, const T *out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + + int stride = hidden_size * WARP_SIZE; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + if (col_idx < hidden_size) { + for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { + float val = out_grad[idx]; + float in = input[idx]; + float b = bias[idx % hidden_size]; + val = activation_bwd_kernel( + val * scale * static_cast(mask[idx]), in + b); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + float sum = tile[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; + __syncthreads(); + + if (threadIdx.y == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +// @brief fused bias, activation, and dropout backward +// It is deprecated for precision reason. Keep it for future optimization. +// +// template +// __global__ void ls_dropout_act_bias_bwd_kernel( +// const int row_size, const float ratio, __half * in_grad, +// __half *__restrict__ bias_grad, const __half *__restrict__ input, const +// __half *__restrict__ bias, const __half * out_grad, const uint8_t +// *__restrict__ mask, const int hidden_size) { +// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); +// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; + +// cg::thread_block b = cg::this_thread_block(); +// cg::thread_block_tile g = cg::tiled_partition(b); + +// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); +// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); +// const __half2 *out_grad2 = reinterpret_cast(out_grad); +// const __half2 *input2 = reinterpret_cast(input); +// const __half2 *bias2 = reinterpret_cast(bias); + +// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + +// int stride = hidden_size * WARP_SIZE; +// __half2 local_sum = __float2half2_rn(0.f); + +// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); +// if (col_idx < hidden_size) { +// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { +// __half2 val = out_grad2[idx]; +// __half2 in2 = input2[idx]; +// __half2 b2 = bias2[idx % hidden_size ]; +// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); +// val = activation_bwd_kernel(val * scale +// * +// m2, +// in2+b2); +// local_sum += val; +// in_grad2[idx] = val; +// idx += stride; +// } +// } + +// tile[threadIdx.x][threadIdx.y] = local_sum; +// __syncthreads(); +// __half2 sum = tile[threadIdx.y][threadIdx.x]; +// __syncthreads(); + +// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + +// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; +// __syncthreads(); + +// if (threadIdx.y == 0) { +// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); +// bias_grad2[pos] = tile[0][threadIdx.x]; +// } +// } + +template +void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, + const T *bias, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + ls_dropout_act_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); +} + +// template <> +// void launch_ls_dropout_act_bias_bwd( +// __half *in_grad, __half *bias_grad,const __half *input, const __half +// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int +// dim, float ratio, cudaStream_t stream) { +// dim >>= 1; +// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); +// dim3 block_dim(WARP_SIZE, WARP_SIZE); +// ls_dropout_act_bias_bwd_kernel +// <<>>(row_size, ratio, in_grad, +// bias_grad, +// input, bias,out_grad, mask, dim); +// } + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu index bc90c54c0a00..625b02cd25d9 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu @@ -1,232 +1,232 @@ -#include - -#include "kernels.h" - -namespace cg = cooperative_groups; - -/** -@brief: fuse_transpose_bias -Calculate the sum of elements in each column of the matrix. - -@thread -gridDim.x = ceil(cols / WARP_SIZE) -blockDim.x = WARP_SIZE -blockDim.y = WARP_SIZE - -@param -inp: [rows, cols] -out: [cols] -rows: the number of rows in the matrix -cols: the number of cols in the matrix -*/ -template -__global__ void column_sum_reduce(const T *__restrict__ inp, - T *__restrict__ out, int rows, int cols) { - __shared__ float tile[WARP_SIZE][WARP_SIZE]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - int y_stride = cols * WARP_SIZE; - float localSum = 0; - - // Loop across matrix row - // TODO: optimize to log complexity - if (idx < cols) { - int offset = flat_2dim(threadIdx.y, idx, cols); - for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { - localSum += (float)inp[offset]; - offset += y_stride; - } - } - - // The sum of a row in tile is equal to the sum of a col in original matrix - tile[threadIdx.x][threadIdx.y] = localSum; - - __syncthreads(); - - // Sum the shared buffer. - // The change of threadIdx.x is continuous - float sum = tile[threadIdx.y][threadIdx.x]; - - __syncthreads(); - - // Calculate the sum of a row in tile - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); - if (pos < cols) out[pos] = sum; - } -} - -// [r, c] -> [c] -template <> -void launch_fuse_transpose_bias_kernel(const float *inp, float *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce - <<>>(inp, out, rows, cols); -} - -template <> -void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce<__half> - <<>>(inp, out, rows, cols); -} - -/** -@brief: fused_add2 -Add two matrix inp1 and inp2 to out. - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -inp1: [batch_size, seq_len, hidden_dim] -inp2: [batch_size, seq_len, hidden_dim] -out: [batch_size, seq_len, hidden_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -*/ -template -__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, - int hidden_dim); - -template <> -__global__ void fused_add2_kernel(float *out, const float *inp1, - const float *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - val.x = vinp1.x + vinp2.x; - val.y = vinp1.y + vinp2.y; - val.z = vinp1.z + vinp2.z; - val.w = vinp1.w + vinp2.w; - out_4[offset + i] = val; - } -} - -template <> -__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, - const __half *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); - __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); - __half2 *h2_val = reinterpret_cast<__half2 *>(&val); - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); - h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); - h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); - h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); - out_4[offset + i] = val; - } -} - -//[b, s, h] -> [b, s, h] -template <> -void launch_fused_add2(float *out, const float *inp1, const float *inp2, - int batch_size, int seq_len, int hidden_dim, - cudaStream_t &stream) { - hidden_dim >>= 2; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template <> -void launch_fused_add2<__half>(__half *out, const __half *inp1, - const __half *inp2, int batch_size, int seq_len, - int hidden_dim, cudaStream_t &stream) { - hidden_dim >>= 3; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template -__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, - int sz0, int sz2, int sz1_1, int sz1_2) { - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); - if (idx >= nele) { - return; - } - float4 *dst_ptr = (float4 *)output + idx; - int idx2 = idx % sz2; - idx = idx / sz2; - int idx1 = idx % (sz1_1 + sz1_2); - int idx0 = idx / (sz1_1 + sz1_2); - float4 *src_ptr = nullptr; - int sz1 = 0; - if (idx1 < sz1_1) { - sz1 = sz1_1; - src_ptr = (float4 *)inp1; - } else { - idx1 -= sz1_1; - sz1 = sz1_2; - src_ptr = (float4 *)inp2; - } - src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); - dst_ptr[0] = src_ptr[0]; -} - -template <> -void launch_concat3_dim1(const float *inp1, const float *inp2, - float *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 2; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} - -template <> -void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, - __half *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 3; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} +#include + +#include "kernels.h" + +namespace cg = cooperative_groups; + +/** +@brief: fuse_transpose_bias +Calculate the sum of elements in each column of the matrix. + +@thread +gridDim.x = ceil(cols / WARP_SIZE) +blockDim.x = WARP_SIZE +blockDim.y = WARP_SIZE + +@param +inp: [rows, cols] +out: [cols] +rows: the number of rows in the matrix +cols: the number of cols in the matrix +*/ +template +__global__ void column_sum_reduce(const T *__restrict__ inp, + T *__restrict__ out, int rows, int cols) { + __shared__ float tile[WARP_SIZE][WARP_SIZE]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + int y_stride = cols * WARP_SIZE; + float localSum = 0; + + // Loop across matrix row + // TODO: optimize to log complexity + if (idx < cols) { + int offset = flat_2dim(threadIdx.y, idx, cols); + for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + // The sum of a row in tile is equal to the sum of a col in original matrix + tile[threadIdx.x][threadIdx.y] = localSum; + + __syncthreads(); + + // Sum the shared buffer. + // The change of threadIdx.x is continuous + float sum = tile[threadIdx.y][threadIdx.x]; + + __syncthreads(); + + // Calculate the sum of a row in tile + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); + if (pos < cols) out[pos] = sum; + } +} + +// [r, c] -> [c] +template <> +void launch_fuse_transpose_bias_kernel(const float *inp, float *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce + <<>>(inp, out, rows, cols); +} + +template <> +void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce<__half> + <<>>(inp, out, rows, cols); +} + +/** +@brief: fused_add2 +Add two matrix inp1 and inp2 to out. + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +inp1: [batch_size, seq_len, hidden_dim] +inp2: [batch_size, seq_len, hidden_dim] +out: [batch_size, seq_len, hidden_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +*/ +template +__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, + int hidden_dim); + +template <> +__global__ void fused_add2_kernel(float *out, const float *inp1, + const float *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + val.x = vinp1.x + vinp2.x; + val.y = vinp1.y + vinp2.y; + val.z = vinp1.z + vinp2.z; + val.w = vinp1.w + vinp2.w; + out_4[offset + i] = val; + } +} + +template <> +__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, + const __half *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); + __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); + __half2 *h2_val = reinterpret_cast<__half2 *>(&val); + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); + h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); + h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); + h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); + out_4[offset + i] = val; + } +} + +//[b, s, h] -> [b, s, h] +template <> +void launch_fused_add2(float *out, const float *inp1, const float *inp2, + int batch_size, int seq_len, int hidden_dim, + cudaStream_t &stream) { + hidden_dim >>= 2; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template <> +void launch_fused_add2<__half>(__half *out, const __half *inp1, + const __half *inp2, int batch_size, int seq_len, + int hidden_dim, cudaStream_t &stream) { + hidden_dim >>= 3; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template +__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, + int sz0, int sz2, int sz1_1, int sz1_2) { + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); + if (idx >= nele) { + return; + } + float4 *dst_ptr = (float4 *)output + idx; + int idx2 = idx % sz2; + idx = idx / sz2; + int idx1 = idx % (sz1_1 + sz1_2); + int idx0 = idx / (sz1_1 + sz1_2); + float4 *src_ptr = nullptr; + int sz1 = 0; + if (idx1 < sz1_1) { + sz1 = sz1_1; + src_ptr = (float4 *)inp1; + } else { + idx1 -= sz1_1; + sz1 = sz1_2; + src_ptr = (float4 *)inp2; + } + src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); + dst_ptr[0] = src_ptr[0]; +} + +template <> +void launch_concat3_dim1(const float *inp1, const float *inp2, + float *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 2; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} + +template <> +void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, + __half *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 3; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h index 563a7fe284a3..025fbf3f8f15 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h @@ -1,96 +1,96 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -template -class Dropout { - public: - struct Config { - float ratio; - bool training; - - Config(float r) : ratio(r), training(true) {} - float RATIO() const { return training ? ratio : 0.0; } - }; - - Dropout(const Config &config, size_t max_ele_num) - : _config(config), _mask(nullptr) { - _mask = cuda_malloc(max_ele_num); - } - - virtual ~Dropout() { cuda_free(_mask); } - - // after attention softmax - void dropout(T *output, const T *input, int count, cudaStream_t stream, - bool bwd = false) { - launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, - bwd); - } - - void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { - launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), - stream, true); - } - - // transformer layer's postprocessing dropout, after attn or ffn module, - // before residual add. - void bias_dropout_residual(T *output, const T *input, const T *residual, - const T *bias, int rows, int cols, - cudaStream_t stream) { - launch_ls_dropout_res_bias(output, input, _mask, bias, residual, - rows * cols, cols, _config.RATIO(), stream); - } - - void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, - int rows, int cols, cudaStream_t stream) { - launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, - _config.RATIO(), stream); - } - - // dropout inside ffn. - void bias_act_dropout(T *output, const T *input, const T *bias, int rows, - int cols, std::string activation_fn, - cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, - const T *bias, int rows, int cols, - std::string activation_fn, cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - bool HasDropout() const { return _config.RATIO() > 0.0; } - - void SetTrainingMode(bool training) { _config.training = training; } - - private: - uint8_t *_mask; - Config _config; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +template +class Dropout { + public: + struct Config { + float ratio; + bool training; + + Config(float r) : ratio(r), training(true) {} + float RATIO() const { return training ? ratio : 0.0; } + }; + + Dropout(const Config &config, size_t max_ele_num) + : _config(config), _mask(nullptr) { + _mask = cuda_malloc(max_ele_num); + } + + virtual ~Dropout() { cuda_free(_mask); } + + // after attention softmax + void dropout(T *output, const T *input, int count, cudaStream_t stream, + bool bwd = false) { + launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, + bwd); + } + + void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { + launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), + stream, true); + } + + // transformer layer's postprocessing dropout, after attn or ffn module, + // before residual add. + void bias_dropout_residual(T *output, const T *input, const T *residual, + const T *bias, int rows, int cols, + cudaStream_t stream) { + launch_ls_dropout_res_bias(output, input, _mask, bias, residual, + rows * cols, cols, _config.RATIO(), stream); + } + + void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, + int rows, int cols, cudaStream_t stream) { + launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, + _config.RATIO(), stream); + } + + // dropout inside ffn. + void bias_act_dropout(T *output, const T *input, const T *bias, int rows, + int cols, std::string activation_fn, + cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, + const T *bias, int rows, int cols, + std::string activation_fn, cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + bool HasDropout() const { return _config.RATIO() > 0.0; } + + void SetTrainingMode(bool training) { _config.training = training; } + + private: + uint8_t *_mask; + Config _config; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h index fbb9c5465c24..735e1363cc46 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h @@ -3,10 +3,11 @@ #include #include #include -#include #include #include +#include + #define MAX_THREADS 1024 #define WARP_SIZE 32 @@ -132,8 +133,9 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, } /* Convert 4-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int -flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) { +__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3, + int id4, int dim2, int dim3, + int dim4) { // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; int res = id4; @@ -201,9 +203,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, } /* Convert vector index to 6-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, - int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) { +__forceinline__ __host__ __device__ void decompose_6dim( + int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0, + int *id1, int *id2, int *id3, int *id4, int *id5) { *id5 = src % dim5; src /= dim5; @@ -221,9 +223,11 @@ decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, } /* Convert vector index to 5-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0, - int *id1, int *id2, int *id3, int *id4) { +__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1, + int dim2, int dim3, + int dim4, int *id0, + int *id1, int *id2, + int *id3, int *id4) { *id4 = src % dim4; src /= dim4; @@ -253,8 +257,9 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, } /* Convert vector index to 3-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) { +__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, + int dim2, int *id0, + int *id1, int *id2) { *id2 = src % dim2; src /= dim2; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h index ded5c0fdcbee..a7767e187ffc 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h @@ -1,64 +1,65 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template class Normalize_Layer { -public: - struct Config { - uint32_t hidden_dim; - bool use_mean; - Config(uint32_t hidden_dim, bool use_mean = false) - : hidden_dim(hidden_dim), use_mean(use_mean) {} - }; - - Normalize_Layer(Config config, size_t max_rows) - : config_(config), vars_(nullptr), means_(nullptr) { - vars_ = cuda_malloc(max_rows); - if (config_.use_mean) { - means_ = cuda_malloc(max_rows); - } - } - - ~Normalize_Layer() { - cuda_free(vars_); - cuda_free(means_); - } - - void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, - int batch_size, cudaStream_t stream) { - launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, - config_.hidden_dim, stream); - } - - /* - residual_grad, inp_or_out, betta should be treated carefully. - inp_or_out = input if use_mean else output - residual_grad, betta can be nullptr. - residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln - betta are only used to compute xhat, - (use_mean == false) ^ (betta == nullptr) should be true - */ - void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, int batch_size, cudaStream_t stream[2]) { - launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, - inp_or_out, gamma, betta, vars_, means_, batch_size, - config_.hidden_dim, stream); - } - - inline bool use_mean() const { return config_.use_mean; } - -private: - Config config_; - T *vars_; - T *means_; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Normalize_Layer { + public: + struct Config { + uint32_t hidden_dim; + bool use_mean; + Config(uint32_t hidden_dim, bool use_mean = false) + : hidden_dim(hidden_dim), use_mean(use_mean) {} + }; + + Normalize_Layer(Config config, size_t max_rows) + : config_(config), vars_(nullptr), means_(nullptr) { + vars_ = cuda_malloc(max_rows); + if (config_.use_mean) { + means_ = cuda_malloc(max_rows); + } + } + + ~Normalize_Layer() { + cuda_free(vars_); + cuda_free(means_); + } + + void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, + int batch_size, cudaStream_t stream) { + launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, + config_.hidden_dim, stream); + } + + /* + residual_grad, inp_or_out, betta should be treated carefully. + inp_or_out = input if use_mean else output + residual_grad, betta can be nullptr. + residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln + betta are only used to compute xhat, + (use_mean == false) ^ (betta == nullptr) should be true + */ + void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, const T *gamma, + const T *betta, int batch_size, cudaStream_t stream[2]) { + launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, + inp_or_out, gamma, betta, vars_, means_, batch_size, + config_.hidden_dim, stream); + } + + inline bool use_mean() const { return config_.use_mean; } + + private: + Config config_; + T *vars_; + T *means_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h index ec447ad84c54..b917abaf0336 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h @@ -1,42 +1,42 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Softmax { - public: - struct Config { - size_t nhead; - Config(size_t nhead) : nhead(nhead) {} - }; - - Softmax(Config config) : config_(config) {} - - ~Softmax() {} - - void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, - int to_len, cudaStream_t &stream, bool mask_future = true) { - launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, - to_len, mask_future, stream); - } - - void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, - int to_len, cudaStream_t stream) { - launch_attn_softmax_bw(out_grad, soft_out, - batch_size * config_.nhead * from_len, to_len, - stream); - } - - void reset_size(size_t nhead) { config_.nhead = nhead; } - - private: - Config config_; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Softmax { + public: + struct Config { + size_t nhead; + Config(size_t nhead) : nhead(nhead) {} + }; + + Softmax(Config config) : config_(config) {} + + ~Softmax() {} + + void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, + int to_len, cudaStream_t &stream, bool mask_future = true) { + launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, + to_len, mask_future, stream); + } + + void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, + int to_len, cudaStream_t stream) { + launch_attn_softmax_bw(out_grad, soft_out, + batch_size * config_.nhead * from_len, to_len, + stream); + } + + void reset_size(size_t nhead) { config_.nhead = nhead; } + + private: + Config config_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu index 3e61d4e35832..e2f1869b165e 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu @@ -1,1169 +1,1172 @@ -#include "block_reduce.h" -#include "kernels.h" -#include - -namespace cg = cooperative_groups; -const float LN_EPSILON = 1e-8f; -#define TILE_DIM 32 - -template __forceinline__ __device__ T add_eps(T x) { - return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); -} - -/** -@brief: ker_layer_norm -Standard layer normalization. -It will not only output the layer norm result, - but also outputs variance. - may also output means, depends on whether - the means argument is nullptr - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -ln_res: [batch_size* seq_len, hidden_size], ln result. -vars: [batch_size* seq_len], variance per token -means: [batch_size* seq_len], means per token, can be nullput -inp: [batch_size * seq_len, hidden_size], ln input. -scale: [hidden_size], ln scale -bias: [hidden_size], ln bias -*/ -template -__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val = inp_f4[idx]; - l_sum += val.x + val.y + val.z + val.w; - l_square_sum += - val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 4.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 vscale = __ldg((const float4 *)scale + idx); - float4 vbias = __ldg((const float4 *)bias + idx); - float4 val = inp_f4[idx]; - val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; - val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; - val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; - val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; - output_f4[idx] = val; - } -} - -template <> -__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, - __half *means, const __half *inp, - const __half *scale, const __half *bias, - int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 val_f2 = __half22float2(val_h2[i]); - l_sum += val_f2.x + val_f2.y; - l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; - } - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 8.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - // load scale, bias, input - float4 scale_f4 = __ldg((const float4 *)scale + idx); - __half2 *scale_h2 = (__half2 *)(&scale_f4); - float4 bias_f4 = __ldg((const float4 *)bias + idx); - __half2 *bias_h2 = (__half2 *)(&bias_f4); - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); - -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 scale_f2 = __half22float2(scale_h2[i]); - float2 bias_f2 = __half22float2(bias_h2[i]); - float2 val_f2 = __half22float2(val_h2[i]); - val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; - val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; - val_h2[i] = __float22half2_rn(val_f2); - } - output_f4[idx] = val_f4; - } -} - -// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; -// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x -// * val_f2_1.x + val_f2_1.y * val_f2_1.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 2; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_h2[i] = __float22half2_rn(val_f2); -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// } -// } - -// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// float4 val_f4_2 = inp_f4[idx+2]; -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + -// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * -// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x -// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + -// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + -// val_f2_3.y * val_f2_3.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 4; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); -// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); -// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); -// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); -// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); -// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); -// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// float4 val_f4_2 = inp_f4[idx+2]; -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); -// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); -// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * -// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var -// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * -// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) -// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = -// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); -// val_h2_2[i] = __float22half2_rn(val_f2_2); -// val_h2_3[i] = __float22half2_rn(val_f2_3); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// output_f4[idx+2] = val_f4_2; -// output_f4[idx+3] = val_f4_3; -// } -// } - -template <> -void launch_layer_norm(float *ln_res, float *vars, float *means, - const float *inp, const float *scale, - const float *bias, int batch_size, int hidden_dim, - cudaStream_t stream) { - if (hidden_dim % 4 != 0) { - throw std::runtime_error("violate hidden_dim % 4 = 0"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); -} - -template <> -void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, - const __half *inp, const __half *scale, - const __half *bias, int batch_size, - int hidden_dim, cudaStream_t stream) { - if (hidden_dim % 8 != 0) { - throw std::runtime_error("violate hidden_dim % 8 = 0"); - } - hidden_dim >>= 3; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<__half><<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); - // if (hidden_dim % 8 != 0) { - // throw std::runtime_error("violate hidden_dim % 8 = 0"); - // } - // hidden_dim >>= 3; - - // if (hidden_dim * 8 < 8192) { - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm<__half><<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { - // hidden_dim >>= 1; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x2<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { - // hidden_dim >>= 2; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x4<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else { - // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - // } -} - -/** -@brief: ker_ln_bw_dgamma_dbetta -Layer norm backword kernel, compute the gradient of gamma and betta. -dbetta = sum(dout, dim=0) -dgamma = sum(xhat * dout, dim=0) -xhat = (input - mean) * rsqrt(var) or - (output - betta) / gamma - - -@thread -gridDim.x = hidden_size / 32 -blockDim.x = 32 -blockDim.y = 32 - -@param -gamma_grad: [hidden_size], gradient of gamma -betta_grad: [hidden_size], gradient of betta -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat, maybe nullptr -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat, maybe nullptr -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -(gamma && betta) ^ (vars && means) should be true -*/ -template -__global__ void -ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, const T *out_grad, - const T *inp_or_out, const T *gamma, const T *betta, - const T *vars, const T *means, int rows, int width) { - __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - // Loop across inp height - float dbetta = 0; - float dgamma = 0; - float dout, val; - if (idx < width) { - if (means == nullptr) { - float vbetta = (float)betta[idx]; - float vgamma = (float)gamma[idx]; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is output - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - vbetta) / add_eps(vgamma) * dout); - offset += y_stride; - } - } else { - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is input - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - (float)means[r]) * - rsqrtf((float)vars[r] + LN_EPSILON) * dout); - offset += y_stride; - } - } - } - - // Sum the shared buffer. - betta_buffer[threadIdx.x][threadIdx.y] = dbetta; - gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; - __syncthreads(); - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (threadIdx.x == 0 && idx < width) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -/** -@brief: ker_ln_bw_dinp -Layer norm backword kernel, compute the gradient of input. -dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) - * rsqrt(var) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dxhat = dout * gamma - - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, - usually appear in pre-layer-norm for transformer layer, maybe nullptr -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat and dxhat -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat and dinp -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -*/ -template -__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, - const T *gamma, const T *betta, const T *vars, - const T *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - float4 dxhat, xhat; - float var_rsqrt; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - dxhat = ((const float4 *)out_grad)[offset]; - float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; - dxhat.x *= vgamma.x; - dxhat.y *= vgamma.y; - dxhat.z *= vgamma.z; - dxhat.w *= vgamma.w; - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - xhat = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); - xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); - xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); - xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; - xhat.x = (xhat.x - fmean) * var_rsqrt; - xhat.y = (xhat.y - fmean) * var_rsqrt; - xhat.z = (xhat.z - fmean) * var_rsqrt; - xhat.w = (xhat.w - fmean) * var_rsqrt; - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - float reduce_val[2] = {0.f, 0.f}; - if (threadIdx.x < hidden_dim) { - reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; - reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + - dxhat.w * xhat.w; - } - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - dxhat.x += dresidual.x; - dxhat.y += dresidual.y; - dxhat.z += dresidual.z; - dxhat.w += dresidual.w; - } - ((float4 *)inp_grad)[offset] = dxhat; -} - -template <> -__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, - int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - - float2 dxhat[4], xhat[4]; - float var_rsqrt; - float4 vtmp; - __half2 *tmp_h2; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vbetta = __half22float2(betta_h2[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; -} - -__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float var_rsqrt; - float4 vtmp, vtmp_1; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 2; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; -} - -__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float2 dxhat_2[4], xhat_2[4]; - float2 dxhat_3[4], xhat_3[4]; - float var_rsqrt; - float4 vtmp, vtmp_1, vtmp_2, vtmp_3; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - __half2 *tmp_h2_2; - __half2 *tmp_h2_3; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - vtmp_2 = ((const float4 *)out_grad)[offset + 2]; - vtmp_3 = ((const float4 *)out_grad)[offset + 3]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); - tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; - float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; - float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); - __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); - __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vdout_2 = __half22float2(tmp_h2_2[i]); - float2 vdout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - dxhat_2[i].x = vdout_2.x * vgamma_2.x; - dxhat_2[i].y = vdout_2.y * vgamma_2.y; - dxhat_3[i].x = vdout_3.x * vgamma_3.x; - dxhat_3[i].y = vdout_3.y * vgamma_3.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + - dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + - dxhat_3[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; - vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; - float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; - float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); - __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); - __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vout_2 = __half22float2(tmp_h2_2[i]); - float2 vout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - float2 vbetta_2 = __half22float2(betta_h2_2[i]); - float2 vbetta_3 = __half22float2(betta_h2_3[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); - xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); - xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - float2 vinp_2 = __half22float2(tmp_h2_2[i]); - float2 vinp_3 = __half22float2(tmp_h2_3[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; - xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; - xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; - float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); - __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); - __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_2[2 * i])); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_3[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; - ((float4 *)inp_grad)[offset + 2] = vtmp_2; - ((float4 *)inp_grad)[offset + 3] = vtmp_3; -} - -/** -Layer norm backword, - compute the gradient of gamma, betta and input. -dbetta = sum(dout, dim=0) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dgamma = sum(xhat * dout, dim=0) -dxhat = dout * gamma -dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) - * rsqrt(var) - -residual_grad, means, betta can be nullptr. -residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln -means and betta are only used to compute xhat, - (means == nullptr) ^ (betta == nullptr) should be true -*/ -template <> -void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, - const float *out_grad, const float *residual_grad, - const float *inp_or_out, const float *gamma, - const float *betta, const float *vars, - const float *means, int batch, int hidden_dim, - cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 4 != 0 || hidden_dim > 4096) { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); -} - -template <> -void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, - __half *inp_grad, const __half *out_grad, - const __half *residual_grad, const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, int batch, - int hidden_dim, cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<__half><<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 8 != 0) { - throw std::runtime_error("hidden_dim % 8 != 0"); - } - hidden_dim >>= 3; - - if (hidden_dim * 8 <= 8192) { - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { - hidden_dim >>= 1; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x2<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x4<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - } -} +#include + +#include "block_reduce.h" +#include "kernels.h" + +namespace cg = cooperative_groups; +const float LN_EPSILON = 1e-8f; +#define TILE_DIM 32 + +template +__forceinline__ __device__ T add_eps(T x) { + return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); +} + +/** +@brief: ker_layer_norm +Standard layer normalization. +It will not only output the layer norm result, + but also outputs variance. + may also output means, depends on whether + the means argument is nullptr + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +ln_res: [batch_size* seq_len, hidden_size], ln result. +vars: [batch_size* seq_len], variance per token +means: [batch_size* seq_len], means per token, can be nullput +inp: [batch_size * seq_len, hidden_size], ln input. +scale: [hidden_size], ln scale +bias: [hidden_size], ln bias +*/ +template +__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, + const T *scale, const T *bias, int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val = inp_f4[idx]; + l_sum += val.x + val.y + val.z + val.w; + l_square_sum += + val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 4.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 vscale = __ldg((const float4 *)scale + idx); + float4 vbias = __ldg((const float4 *)bias + idx); + float4 val = inp_f4[idx]; + val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; + val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; + val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; + val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; + output_f4[idx] = val; + } +} + +template <> +__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, + __half *means, const __half *inp, + const __half *scale, const __half *bias, + int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 val_f2 = __half22float2(val_h2[i]); + l_sum += val_f2.x + val_f2.y; + l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; + } + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 8.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + // load scale, bias, input + float4 scale_f4 = __ldg((const float4 *)scale + idx); + __half2 *scale_h2 = (__half2 *)(&scale_f4); + float4 bias_f4 = __ldg((const float4 *)bias + idx); + __half2 *bias_h2 = (__half2 *)(&bias_f4); + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); + +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 scale_f2 = __half22float2(scale_h2[i]); + float2 bias_f2 = __half22float2(bias_h2[i]); + float2 val_f2 = __half22float2(val_h2[i]); + val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; + val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; + val_h2[i] = __float22half2_rn(val_f2); + } + output_f4[idx] = val_f4; + } +} + +// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; +// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x +// * val_f2_1.x + val_f2_1.y * val_f2_1.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 2; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_h2[i] = __float22half2_rn(val_f2); +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// } +// } + +// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// float4 val_f4_2 = inp_f4[idx+2]; +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + +// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * +// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x +// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + +// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + +// val_f2_3.y * val_f2_3.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 4; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); +// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); +// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); +// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); +// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); +// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); +// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// float4 val_f4_2 = inp_f4[idx+2]; +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); +// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); +// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * +// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var +// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * +// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) +// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = +// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); +// val_h2_2[i] = __float22half2_rn(val_f2_2); +// val_h2_3[i] = __float22half2_rn(val_f2_3); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// output_f4[idx+2] = val_f4_2; +// output_f4[idx+3] = val_f4_3; +// } +// } + +template <> +void launch_layer_norm(float *ln_res, float *vars, float *means, + const float *inp, const float *scale, + const float *bias, int batch_size, int hidden_dim, + cudaStream_t stream) { + if (hidden_dim % 4 != 0) { + throw std::runtime_error("violate hidden_dim % 4 = 0"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); +} + +template <> +void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, + const __half *inp, const __half *scale, + const __half *bias, int batch_size, + int hidden_dim, cudaStream_t stream) { + if (hidden_dim % 8 != 0) { + throw std::runtime_error("violate hidden_dim % 8 = 0"); + } + hidden_dim >>= 3; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<__half><<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); + // if (hidden_dim % 8 != 0) { + // throw std::runtime_error("violate hidden_dim % 8 = 0"); + // } + // hidden_dim >>= 3; + + // if (hidden_dim * 8 < 8192) { + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm<__half><<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { + // hidden_dim >>= 1; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x2<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { + // hidden_dim >>= 2; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x4<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else { + // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + // } +} + +/** +@brief: ker_ln_bw_dgamma_dbetta +Layer norm backword kernel, compute the gradient of gamma and betta. +dbetta = sum(dout, dim=0) +dgamma = sum(xhat * dout, dim=0) +xhat = (input - mean) * rsqrt(var) or + (output - betta) / gamma + + +@thread +gridDim.x = hidden_size / 32 +blockDim.x = 32 +blockDim.y = 32 + +@param +gamma_grad: [hidden_size], gradient of gamma +betta_grad: [hidden_size], gradient of betta +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat, maybe nullptr +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat, maybe nullptr +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +(gamma && betta) ^ (vars && means) should be true +*/ +template +__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, + const T *out_grad, const T *inp_or_out, + const T *gamma, const T *betta, + const T *vars, const T *means, int rows, + int width) { + __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + // Loop across inp height + float dbetta = 0; + float dgamma = 0; + float dout, val; + if (idx < width) { + if (means == nullptr) { + float vbetta = (float)betta[idx]; + float vgamma = (float)gamma[idx]; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is output + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - vbetta) / add_eps(vgamma) * dout); + offset += y_stride; + } + } else { + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is input + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - (float)means[r]) * + rsqrtf((float)vars[r] + LN_EPSILON) * dout); + offset += y_stride; + } + } + } + + // Sum the shared buffer. + betta_buffer[threadIdx.x][threadIdx.y] = dbetta; + gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; + __syncthreads(); + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + if (threadIdx.x == 0 && idx < width) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/** +@brief: ker_ln_bw_dinp +Layer norm backword kernel, compute the gradient of input. +dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) + * rsqrt(var) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dxhat = dout * gamma + + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, + usually appear in pre-layer-norm for transformer layer, maybe nullptr +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat and dxhat +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat and dinp +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +*/ +template +__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, + const T *gamma, const T *betta, const T *vars, + const T *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + float4 dxhat, xhat; + float var_rsqrt; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + dxhat = ((const float4 *)out_grad)[offset]; + float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; + dxhat.x *= vgamma.x; + dxhat.y *= vgamma.y; + dxhat.z *= vgamma.z; + dxhat.w *= vgamma.w; + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + xhat = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); + xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); + xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); + xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; + xhat.x = (xhat.x - fmean) * var_rsqrt; + xhat.y = (xhat.y - fmean) * var_rsqrt; + xhat.z = (xhat.z - fmean) * var_rsqrt; + xhat.w = (xhat.w - fmean) * var_rsqrt; + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + float reduce_val[2] = {0.f, 0.f}; + if (threadIdx.x < hidden_dim) { + reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; + reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + + dxhat.w * xhat.w; + } + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + dxhat.x += dresidual.x; + dxhat.y += dresidual.y; + dxhat.z += dresidual.z; + dxhat.w += dresidual.w; + } + ((float4 *)inp_grad)[offset] = dxhat; +} + +template <> +__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, + int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + + float2 dxhat[4], xhat[4]; + float var_rsqrt; + float4 vtmp; + __half2 *tmp_h2; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vbetta = __half22float2(betta_h2[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; +} + +__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float var_rsqrt; + float4 vtmp, vtmp_1; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 2; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; +} + +__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float2 dxhat_2[4], xhat_2[4]; + float2 dxhat_3[4], xhat_3[4]; + float var_rsqrt; + float4 vtmp, vtmp_1, vtmp_2, vtmp_3; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + __half2 *tmp_h2_2; + __half2 *tmp_h2_3; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + vtmp_2 = ((const float4 *)out_grad)[offset + 2]; + vtmp_3 = ((const float4 *)out_grad)[offset + 3]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); + tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; + float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; + float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); + __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); + __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vdout_2 = __half22float2(tmp_h2_2[i]); + float2 vdout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + dxhat_2[i].x = vdout_2.x * vgamma_2.x; + dxhat_2[i].y = vdout_2.y * vgamma_2.y; + dxhat_3[i].x = vdout_3.x * vgamma_3.x; + dxhat_3[i].y = vdout_3.y * vgamma_3.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + + dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + + dxhat_3[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; + vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; + float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; + float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); + __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); + __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vout_2 = __half22float2(tmp_h2_2[i]); + float2 vout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + float2 vbetta_2 = __half22float2(betta_h2_2[i]); + float2 vbetta_3 = __half22float2(betta_h2_3[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); + xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); + xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + float2 vinp_2 = __half22float2(tmp_h2_2[i]); + float2 vinp_3 = __half22float2(tmp_h2_3[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; + xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; + xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; + float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); + __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); + __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_2[2 * i])); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_3[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; + ((float4 *)inp_grad)[offset + 2] = vtmp_2; + ((float4 *)inp_grad)[offset + 3] = vtmp_3; +} + +/** +Layer norm backword, + compute the gradient of gamma, betta and input. +dbetta = sum(dout, dim=0) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dgamma = sum(xhat * dout, dim=0) +dxhat = dout * gamma +dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) + * rsqrt(var) + +residual_grad, means, betta can be nullptr. +residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln +means and betta are only used to compute xhat, + (means == nullptr) ^ (betta == nullptr) should be true +*/ +template <> +void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, + const float *out_grad, const float *residual_grad, + const float *inp_or_out, const float *gamma, + const float *betta, const float *vars, + const float *means, int batch, int hidden_dim, + cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 4 != 0 || hidden_dim > 4096) { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, + hidden_dim); +} + +template <> +void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, + __half *inp_grad, const __half *out_grad, + const __half *residual_grad, const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, int batch, + int hidden_dim, cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<__half><<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 8 != 0) { + throw std::runtime_error("hidden_dim % 8 != 0"); + } + hidden_dim >>= 3; + + if (hidden_dim * 8 <= 8192) { + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { + hidden_dim >>= 1; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x2<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x4<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + } +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu index 98af433fe397..3862a699d3c3 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu @@ -1,365 +1,365 @@ -#include -#include - -#include -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float EPSILON = 1e-8f; - -/** -@brief: softmax_kernel -Softmax forward kernel for - enc-self-attn, dec-self-attn, encdec-attn - -@thread -gridDim.x = dynamic -gridDim.y = batch_size -gridDim.z = nhead -blockDim.x = from_len - -@param -inp: [batch_size, nhead, from_len, to_len], softmax input. -attn_mask: [batch_size, to_len], padding tokens are -inf, - non padding tokens are 0. - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template -__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // block reduce max - blockReduce(l_max); - // write shared - __shared__ float s_max[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_max[i] = l_max[i]; - } - } - __syncthreads(); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - s_max[i]); - l_sum[i] += val[i][j]; - } - } - // block reduce sum - blockReduce(l_sum); - // write shared - __shared__ float s_sum[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - } - } - __syncthreads(); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * s_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -template -__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // warp reduce max - warpReduce(l_max); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - l_max[i]); - l_sum[i] += val[i][j]; - } - } - // warp reduce sum - warpReduce(l_sum); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * l_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -/* - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template <> -void launch_attn_softmax(float *inp, const float *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 16; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 32; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 64; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -template <> -void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<__half, 32, 1><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<__half, 32, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 8; - ker_attn_softmax<__half, 64, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 16; - ker_attn_softmax<__half, 128, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 32; - ker_attn_softmax<__half, 256, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -/** -@brief: ker_attn_softmax_bw -Softmax backward in self attention. - -@thread -gridDim.x = batch_size * nhead * seq_len / warps_per_block -blockDim.x = WARP_SIZE -blockDim.y = warps_per_block - -@param -grad: [batch_size, nhead, seq_len, seq_len], output grad. -output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. -*/ -template -__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { - int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; - int offset = batch_idx * softmax_length + threadIdx.x; - - grad += offset; - inp += offset; - - T grad_reg[ITERATIONS]; - T inp_reg[ITERATIONS]; - float sum = 0.0; - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) { - grad_reg[i] = grad[i * WARP_SIZE]; - inp_reg[i] = inp[i * WARP_SIZE]; - sum += (float)grad_reg[i] * (float)inp_reg[i]; - } - } - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) - grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); - } -} - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream) { - const int warps_per_block = 4; - // rows = batch_size * nhead * from_len - dim3 grid_dim(rows / warps_per_block); - dim3 block_dim(WARP_SIZE, warps_per_block); - - if (softmax_len <= 32) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 64) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 128) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 256) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 384) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 512) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 768) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 1024) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 2048) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else - throw std::runtime_error( - std::string( - "Special sequence length found in softmax backward, seq_len: ") + - std::to_string(softmax_len)); -} - -template void launch_attn_softmax_bw<__half>(__half *out_grad, - const __half *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); -template void launch_attn_softmax_bw(float *out_grad, - const float *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); +#include +#include + +#include +#include + +#include "block_reduce.h" +#include "kernels.h" + +namespace cg = cooperative_groups; +const float EPSILON = 1e-8f; + +/** +@brief: softmax_kernel +Softmax forward kernel for + enc-self-attn, dec-self-attn, encdec-attn + +@thread +gridDim.x = dynamic +gridDim.y = batch_size +gridDim.z = nhead +blockDim.x = from_len + +@param +inp: [batch_size, nhead, from_len, to_len], softmax input. +attn_mask: [batch_size, to_len], padding tokens are -inf, + non padding tokens are 0. + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template +__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // block reduce max + blockReduce(l_max); + // write shared + __shared__ float s_max[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_max[i] = l_max[i]; + } + } + __syncthreads(); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - s_max[i]); + l_sum[i] += val[i][j]; + } + } + // block reduce sum + blockReduce(l_sum); + // write shared + __shared__ float s_sum[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + } + } + __syncthreads(); + + /* step 3. compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * s_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +template +__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // warp reduce max + warpReduce(l_max); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - l_max[i]); + l_sum[i] += val[i][j]; + } + } + // warp reduce sum + warpReduce(l_sum); + + /* step 3. compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * l_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +/* + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template <> +void launch_attn_softmax(float *inp, const float *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 16; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 32; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 64; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +template <> +void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<__half, 32, 1><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<__half, 32, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 8; + ker_attn_softmax<__half, 64, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 16; + ker_attn_softmax<__half, 128, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 32; + ker_attn_softmax<__half, 256, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +/** +@brief: ker_attn_softmax_bw +Softmax backward in self attention. + +@thread +gridDim.x = batch_size * nhead * seq_len / warps_per_block +blockDim.x = WARP_SIZE +blockDim.y = warps_per_block + +@param +grad: [batch_size, nhead, seq_len, seq_len], output grad. +output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. +*/ +template +__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + grad += offset; + inp += offset; + + T grad_reg[ITERATIONS]; + T inp_reg[ITERATIONS]; + float sum = 0.0; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * WARP_SIZE]; + inp_reg[i] = inp[i * WARP_SIZE]; + sum += (float)grad_reg[i] * (float)inp_reg[i]; + } + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) + grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); + } +} + +template +void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, + int softmax_len, cudaStream_t stream) { + const int warps_per_block = 4; + // rows = batch_size * nhead * from_len + dim3 grid_dim(rows / warps_per_block); + dim3 block_dim(WARP_SIZE, warps_per_block); + + if (softmax_len <= 32) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 64) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 128) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 256) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 384) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 512) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 768) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 1024) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 2048) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else + throw std::runtime_error( + std::string( + "Special sequence length found in softmax backward, seq_len: ") + + std::to_string(softmax_len)); +} + +template void launch_attn_softmax_bw<__half>(__half *out_grad, + const __half *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); +template void launch_attn_softmax_bw(float *out_grad, + const float *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu index d03084b22e12..04de3c092ee0 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu @@ -1,312 +1,314 @@ -#include -#include -#include - -#include "kernels.h" - -using namespace cub; - -/** -@brief: transform_0213 -Split the attention heads and reshape input -during backward progress of encoder self-attention - -@thread -gridDim.x = batch_size -gridDim.y = seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -input: [batch_size, seq_len, hidden_dim] -output: [batch_size, nhead, seq_len, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -*/ - -template -__global__ void transform_0213(T *output, const T *input, int hidden_dim, - int head_dim); - -template <> -__global__ void transform_0213(float *output, const float *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -template <> -__global__ void transform_0213<__half>(__half *output, const __half *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -// [b, s, h] -> [b, nh, s, ad] -template <> -void launch_transform_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213 - <<>>(output, input, hidden_dim, head_dim); -} - -template <> -void launch_transform_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213<__half> - <<>>(output, input, hidden_dim, head_dim); -} - -/** -@brief: bias_add_transform_20314 -Add bias to input, transform from -[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] - -@thread -gridDim.x = dim_0 -gridDim.y = dim_1 -gridDim.z = dim_2 -blockDim.x = min(dim_3 * dim_4, MAX_THREADS) - -@param -input: [dim_0, dim_1, dim_2, dim_3, dim_4] -bias: [dim_2, dim_3, dim_4] -output: [dim_2, dim_0, dim_3, dim_1, dim_4] -*/ -template -__global__ void bias_add_transform_20314(T *output, const T *input, - const T *bias, int dim_3, int dim_4); - -template <> -__global__ void -bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_3, int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - vres4.x = vqkv4.x + vbias4.x; - vres4.y = vqkv4.y + vbias4.y; - vres4.z = vqkv4.z + vbias4.z; - vres4.w = vqkv4.w + vbias4.w; - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -template <> -__global__ void -bias_add_transform_20314<__half>(__half *output, const __half *input, - const __half *bias, int dim_3, int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); - __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); - __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); - h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); - h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); - h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template <> -void launch_bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 2; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314 - <<>>(output, input, bias, dim_3, dim_4); -} - -template <> -void launch_bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 3; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314<__half> - <<>>(output, input, bias, dim_3, dim_4); -} - -/** -@brief: transform4d_0213 -Reshape the input matrix to merge the heads - -@thread -gridDim.x = (num_all + max_block_thread - 1) / max_block_thread -blockDim.x = max_block_thread - -@param -input: [trans_count, batch_size, nhead, seq_len, head_dim] -output: [batch_size, seq_len, trans_count, nhead, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -trans_count: 1 or 3, the count of matrice need to be transformed -*/ -template -__global__ void transform4d_0213(T *output, const T *input, int batch_size, - int seq_len, int trans_count, int nhead, - int head_dim, int num_all) { - int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (offset >= num_all) { - return; - } - int trans_id, batch_id, head_id, token_id, dim_id; - decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, - &batch_id, &head_id, &token_id, &dim_id); - // [b, s, tc, nh, ad] - int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, - seq_len, trans_count, nhead, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - res4[trg_offset] = input4[offset]; -} - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template <> -void launch_transform4d_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} - -template <> -void launch_transform4d_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, - int hidden_dim, int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<__half><<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} +#include +#include +#include + +#include "kernels.h" + +using namespace cub; + +/** +@brief: transform_0213 +Split the attention heads and reshape input +during backward progress of encoder self-attention + +@thread +gridDim.x = batch_size +gridDim.y = seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +input: [batch_size, seq_len, hidden_dim] +output: [batch_size, nhead, seq_len, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +*/ + +template +__global__ void transform_0213(T *output, const T *input, int hidden_dim, + int head_dim); + +template <> +__global__ void transform_0213(float *output, const float *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +template <> +__global__ void transform_0213<__half>(__half *output, const __half *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +// [b, s, h] -> [b, nh, s, ad] +template <> +void launch_transform_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213 + <<>>(output, input, hidden_dim, head_dim); +} + +template <> +void launch_transform_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213<__half> + <<>>(output, input, hidden_dim, head_dim); +} + +/** +@brief: bias_add_transform_20314 +Add bias to input, transform from +[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] + +@thread +gridDim.x = dim_0 +gridDim.y = dim_1 +gridDim.z = dim_2 +blockDim.x = min(dim_3 * dim_4, MAX_THREADS) + +@param +input: [dim_0, dim_1, dim_2, dim_3, dim_4] +bias: [dim_2, dim_3, dim_4] +output: [dim_2, dim_0, dim_3, dim_1, dim_4] +*/ +template +__global__ void bias_add_transform_20314(T *output, const T *input, + const T *bias, int dim_3, int dim_4); + +template <> +__global__ void bias_add_transform_20314(float *output, + const float *input, + const float *bias, int dim_3, + int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + vres4.x = vqkv4.x + vbias4.x; + vres4.y = vqkv4.y + vbias4.y; + vres4.z = vqkv4.z + vbias4.z; + vres4.w = vqkv4.w + vbias4.w; + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +template <> +__global__ void bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_3, + int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); + __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); + __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); + h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); + h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); + h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +// [b, s, 3, h] -> [3, b, nh, s, ad] +template <> +void launch_bias_add_transform_20314(float *output, const float *input, + const float *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 2; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314 + <<>>(output, input, bias, dim_3, dim_4); +} + +template <> +void launch_bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 3; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314<__half> + <<>>(output, input, bias, dim_3, dim_4); +} + +/** +@brief: transform4d_0213 +Reshape the input matrix to merge the heads + +@thread +gridDim.x = (num_all + max_block_thread - 1) / max_block_thread +blockDim.x = max_block_thread + +@param +input: [trans_count, batch_size, nhead, seq_len, head_dim] +output: [batch_size, seq_len, trans_count, nhead, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +trans_count: 1 or 3, the count of matrice need to be transformed +*/ +template +__global__ void transform4d_0213(T *output, const T *input, int batch_size, + int seq_len, int trans_count, int nhead, + int head_dim, int num_all) { + int offset = blockIdx.x * blockDim.x + threadIdx.x; + if (offset >= num_all) { + return; + } + int trans_id, batch_id, head_id, token_id, dim_id; + decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, + &batch_id, &head_id, &token_id, &dim_id); + // [b, s, tc, nh, ad] + int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, + seq_len, trans_count, nhead, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + res4[trg_offset] = input4[offset]; +} + +// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] +template <> +void launch_transform4d_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} + +template <> +void launch_transform4d_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, + int hidden_dim, int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<__half><<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp index 4690277e63db..15a07bb0c7ac 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp @@ -138,4 +138,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu index ad7066bbd9df..72b84d6ca40f 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu @@ -680,4 +680,4 @@ void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, grad_input->DATA_PTR(), gamma != NULL ? grad_gamma->DATA_PTR() : NULL, gamma != NULL ? grad_beta->DATA_PTR() : NULL);) -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp index 61c8a725052f..8c0b89eb06d1 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp @@ -1,97 +1,97 @@ -#include - -torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx); - -std::vector moe_combine_cuda_backward( - int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); - -#define CHECK_CUDA(x) \ - TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -torch::Tensor moe_dispatch_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, torch::Tensor dest_idx) { - CHECK_INPUT(batch_tokens); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx); -} - -torch::Tensor moe_dispatch_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(expert_grad); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); -} - -torch::Tensor moe_combine_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(expert_tokens); - CHECK_INPUT(logits); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, - dest_idx); -} - -std::vector moe_combine_backward(int s, int e, int c, int h, - torch::Tensor tokens_grad, - torch::Tensor expert_tokens, - torch::Tensor logits, - torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(tokens_grad); - CHECK_INPUT(logits); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens, - logits, mask, dest_idx); -} - -torch::Tensor moe_cumsum(torch::Tensor mask) { - CHECK_INPUT(mask); - return cumsum_sub_one_in_dim0(mask); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); - m.def("dispatch_forward", &moe_dispatch_forward, - "Forward operation in MoE dispatch function"); - m.def("dispatch_backward", &moe_dispatch_backward, - "Backward operation in MoE dispatch function"); - m.def("combine_forward", &moe_combine_forward, - "Combine operation in MoE combine function"); - m.def("combine_backward", &moe_combine_backward, - "Combine operation in MoE combine function"); -} +#include + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor moe_dispatch_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(batch_tokens); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx); +} + +torch::Tensor moe_dispatch_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_grad); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); +} + +torch::Tensor moe_combine_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_tokens); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, + dest_idx); +} + +std::vector moe_combine_backward(int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(tokens_grad); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens, + logits, mask, dest_idx); +} + +torch::Tensor moe_cumsum(torch::Tensor mask) { + CHECK_INPUT(mask); + return cumsum_sub_one_in_dim0(mask); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); + m.def("dispatch_forward", &moe_dispatch_forward, + "Forward operation in MoE dispatch function"); + m.def("dispatch_backward", &moe_dispatch_backward, + "Backward operation in MoE dispatch function"); + m.def("combine_forward", &moe_combine_forward, + "Combine operation in MoE combine function"); + m.def("combine_backward", &moe_combine_backward, + "Combine operation in MoE combine function"); +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu index 0454377a2fad..66c1e6bd260e 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu @@ -1,659 +1,659 @@ -#include -#include -#include - -#include - -#include "block_reduce.h" - -template -__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - BlockStore(ts_store).Store(dst_row + idx, pack); - } -} - -template -__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, pack); - BlockStore(ts_store).Store(src_row + idx, pack); - } -} - -template -__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - BlockStore(ts_store).Store(dst_row1 + idx, pack); - BlockStore(ts_store).Store(dst_row2 + idx, pack); - } -} - -template -__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack1[pack_size], pack2[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row1 + idx, pack1); - BlockLoad(ts_load).Load(dst_row2 + idx, pack2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack1[i] += pack2[i]; - } - - BlockStore(ts_store).Store(src_row + idx, pack1); - } -} - -template -__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack[i] *= weight; - } - - BlockStore(ts_store).Store(dst_row + idx, pack); - } -} - -template -__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, - T *weight_grad, const T weight, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T grad[pack_size], tokens[pack_size]; - float thread_sum = 0; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, grad); - BlockLoad(ts_load).Load(tks_row + idx, tokens); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - thread_sum += grad[i] * tokens[i]; - grad[i] *= weight; - } - - BlockStore(ts_store).Store(src_row + idx, grad); - } - - blockReduce(&thread_sum); - - if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); -} - -template -__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, - const T weight1, const T weight2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack1[pack_size], pack2[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row1 + idx, pack1); - BlockLoad(ts_load).Load(src_row2 + idx, pack2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; - } - - BlockStore(ts_store).Store(dst_row + idx, pack1); - } -} - -template -__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, - T *tks_row1, T *tks_row2, T *weight_grad1, - T *weight_grad2, const T weight1, - const T weight2, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], - sgrad2[pack_size]; - float thread_sum[2] = {0, 0}; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, grad); - BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); - BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - thread_sum[0] += grad[i] * tokens1[i]; - thread_sum[1] += grad[i] * tokens2[i]; - sgrad1[i] = weight1 * grad[i]; - sgrad2[i] = weight2 * grad[i]; - } - - BlockStore(ts_store).Store(src_row1 + idx, sgrad1); - BlockStore(ts_store).Store(src_row2 + idx, sgrad2); - } - - blockReduce(thread_sum); - - if (threadIdx.x == 0) - *weight_grad1 = static_cast(thread_sum[0]); - else if (threadIdx.x == 1) - *weight_grad2 = static_cast(thread_sum[1]); -} - -// DISPATCH KERNELS -------------------------------- - -template -__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, - const int cols, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_dpch_two_fwd(src_row, dst_row1, dst_row2, - cols); - else if (indicator1 != 0) - moe_dpch_one_fwd(src_row, dst_row1, cols); - else if (indicator2 != 0) - moe_dpch_one_fwd(src_row, dst_row2, cols); - else - return; -} - -template -__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, - const int cols, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_dpch_two_bwd(src_row, dst_row1, dst_row2, - cols); - else if (indicator1 != 0) - moe_dpch_one_bwd(src_row, dst_row1, cols); - else if (indicator2 != 0) - moe_dpch_one_bwd(src_row, dst_row2, cols); - else - return; -} - -template -__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, - int *mask1, int *mask2, int *dest1, - int *dest2, const int h) { - int row = blockIdx.x; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - moe_dpch_fwd_selector( - batch_tokens + (row * h), expert_input + (dest1[row] * h), - expert_input + (dest2[row] * h), h, mask1[row], indicator2); -} - -template -__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, - int *mask2, int *dest1, int *dest2, - const int h) { - int row = blockIdx.x; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - moe_dpch_bwd_selector( - tokens_grad + (row * h), expert_grad + (dest1[row] * h), - expert_grad + (dest2[row] * h), h, mask1[row], indicator2); -} - -// COMBINE KERNELS -------------------------------- - -template -__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, - const int cols, const T weight1, - const T weight2, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_cb_two_fwd(src_row1, src_row2, dst_row, - weight1, weight2, cols); - else if (indicator1 != 0) - moe_cb_one_fwd(src_row1, dst_row, weight1, cols); - else if (indicator2 != 0) - moe_cb_one_fwd(src_row2, dst_row, weight2, cols); - else - return; -} - -template -__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, - const int cols, T *tks_row1, T *tks_row2, - T *wt_grad1, T *wt_grad2, const T weight1, - const T weight2, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_cb_two_bwd(src_row1, src_row2, dst_row, - tks_row1, tks_row2, wt_grad1, - wt_grad2, weight1, weight2, cols); - else if (indicator1 != 0) - moe_cb_one_bwd(src_row1, dst_row, tks_row1, - wt_grad1, weight1, cols); - else if (indicator2 != 0) - moe_cb_one_bwd(src_row2, dst_row, tks_row2, - wt_grad2, weight2, cols); - else - return; -} - -template -__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, - T *logits, int *mask1, int *mask2, int *dest1, - int *dest2, const int e, const int c, - const int h) { - int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - T *row_log = logits + (row * e); - moe_cb_fwd_selector( - expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), - combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], - indicator2); -} - -template -__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, - T *logits, T *logits_grad, int *mask1, - int *mask2, int *dest1, int *dest2, - const int e, const int c, const int h) { - int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); - moe_cb_bwd_selector( - expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), - tokens_grad + (row * h), h, tks + (dest1[row] * h), - tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], - row_log[eid2], mask1[row], indicator2); -} - -// CUMSUM KERNEL -------------------------------- - -template -__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, - const int e) { - assert(s % pack_size == 0); - constexpr int bpack_size = block_size * pack_size; - int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; - __shared__ int temp[block_size + 1]; - int pack[pack_size]; - - for (int idx = 0; idx < s; idx += bpack_size) { - int offset = 1; - - if (idx + tps < s) { - temp[tid] = inputs[tps * e + bid]; -#pragma unroll - for (int i = 1; i < pack_size; ++i) { - pack[i] = inputs[(tps + i) * e + bid]; - } -#pragma unroll - for (int i = 1; i < pack_size; ++i) { - temp[tid] += pack[i]; - } - } - - for (int i = block_size >> 1; i > 0; i >>= 1) { - __syncthreads(); - if (tid < i) { - int j = offset * (2 * tid + 1) - 1; - temp[j + offset] += temp[j]; - } - offset <<= 1; - } - - if (tid == 0) { - temp[block_size] = temp[block_size - 1]; - temp[block_size - 1] = 0; - } - - for (int i = 1; i < block_size; i <<= 1) { - offset >>= 1; - __syncthreads(); - if (tid < i) { - int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; - temp[j] = temp[k]; - temp[k] += ts; - } - } - __syncthreads(); - - if (tid == 0) temp[0] = temp[block_size]; - __syncthreads(); - - if (idx + tps < s) { - temp[tid + 1] += last_sum; -#pragma unroll - for (int i = pack_size - 1; i > 0; --i) { - outputs[(tps + i) * e + bid] = temp[tid + 1]; - temp[tid + 1] -= pack[i]; - } - outputs[tps * e + bid] = temp[tid + 1]; - } - __syncthreads(); - - last_sum += temp[0]; - inputs += bpack_size * e; - outputs += bpack_size * e; - } -} - -// LAUNCH FUNCTIONS -------------------------------- - -template -void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, - int *mask2, int *dest1, int *dest2, const int s, - const int h) { - if (h < 256) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 512) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 1024) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 2048) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); -} - -template -void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, - int *dest1, int *dest2, const int s, const int h) { - if (h < 256) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 512) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 1024) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 2048) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); -} - -template -void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, - int *mask1, int *mask2, int *dest1, int *dest2, - const int s, const int e, const int c, const int h) { - if (h < 256) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 512) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 1024) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 2048) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, - dest2, e, c, h); -} - -template -void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, - T *logits_grad, int *mask1, int *mask2, int *dest1, - int *dest2, const int s, const int e, const int c, - const int h) { - if (h < 256) - moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, - logits, logits_grad, mask1, mask2, - dest1, dest2, e, c, h); - else // if (h < 512) - moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, - logits, logits_grad, mask1, mask2, - dest1, dest2, e, c, h); - // else if (h < 1024) - // moe_cb_bwd_kernel<<>> - // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, - // dest1, dest2, e, c, h); - // else - // moe_cb_bwd_kernel<<>> - // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, - // dest1, dest2, e, c, h); -} - -void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { - if (s <= 256) - cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); - else if (s <= 512) - cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); - else if (s <= 1024) - cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); - else if (s <= 2048) - cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); - else - cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); -} - -// API FUNCTIONS -------------------------------- - -#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Float: { \ - using scalar_t = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented yet for specific data type."); \ - } - -torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - auto res = torch::zeros( - {ec, h}, - torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - batch_tokens.scalar_type(), "moe dispatch forward", - moe_dpch_fwd_launch( - batch_tokens.data(), res.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); - - return res; -} - -torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - auto res = torch::zeros( - {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - expert_grad.scalar_type(), "moe dispatch backward", - moe_dpch_bwd_launch( - res.data(), expert_grad.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); - - return res; -} - -torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - assert(expert_tokens.dtype() == logits.dtype()); - - auto res = torch::zeros( - {s, h}, - torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - expert_tokens.scalar_type(), "moe combine forward", - moe_cb_fwd_launch( - expert_tokens.data(), res.data(), - logits.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, - h)); - - return res; -} - -std::vector moe_combine_cuda_backward( - int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - assert(tokens_grad.dtype() == expert_tokens.dtype()); - assert(expert_tokens.dtype() == logits.dtype()); - - auto egrad = torch::zeros( - {e * c, h}, - torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), - wgrad = torch::zeros( - {s, e}, torch::dtype(logits.dtype()).device(logits.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - tokens_grad.scalar_type(), "moe combine backward", - moe_cb_bwd_launch( - tokens_grad.data(), egrad.data(), - expert_tokens.data(), logits.data(), - wgrad.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, - h)); - - return {egrad, wgrad}; -} - -torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { - assert(mask.dim() == 2); - assert(mask.dtype() == torch::kInt32); - - const int s = mask.size(0), e = mask.size(1); - auto res = - torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); - cumsum_launch(mask.data(), res.data(), s, e); - - return res; -} +#include +#include +#include + +#include + +#include "block_reduce.h" + +template +__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, pack); + BlockStore(ts_store).Store(src_row + idx, pack); + } +} + +template +__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row1 + idx, pack); + BlockStore(ts_store).Store(dst_row2 + idx, pack); + } +} + +template +__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row1 + idx, pack1); + BlockLoad(ts_load).Load(dst_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] += pack2[i]; + } + + BlockStore(ts_store).Store(src_row + idx, pack1); + } +} + +template +__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] *= weight; + } + + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, + T *weight_grad, const T weight, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens[pack_size]; + float thread_sum = 0; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row + idx, tokens); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum += grad[i] * tokens[i]; + grad[i] *= weight; + } + + BlockStore(ts_store).Store(src_row + idx, grad); + } + + blockReduce(&thread_sum); + + if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); +} + +template +__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, + const T weight1, const T weight2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row1 + idx, pack1); + BlockLoad(ts_load).Load(src_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; + } + + BlockStore(ts_store).Store(dst_row + idx, pack1); + } +} + +template +__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, + T *tks_row1, T *tks_row2, T *weight_grad1, + T *weight_grad2, const T weight1, + const T weight2, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], + sgrad2[pack_size]; + float thread_sum[2] = {0, 0}; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); + BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum[0] += grad[i] * tokens1[i]; + thread_sum[1] += grad[i] * tokens2[i]; + sgrad1[i] = weight1 * grad[i]; + sgrad2[i] = weight2 * grad[i]; + } + + BlockStore(ts_store).Store(src_row1 + idx, sgrad1); + BlockStore(ts_store).Store(src_row2 + idx, sgrad2); + } + + blockReduce(thread_sum); + + if (threadIdx.x == 0) + *weight_grad1 = static_cast(thread_sum[0]); + else if (threadIdx.x == 1) + *weight_grad2 = static_cast(thread_sum[1]); +} + +// DISPATCH KERNELS -------------------------------- + +template +__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_fwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_fwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_fwd(src_row, dst_row2, cols); + else + return; +} + +template +__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_bwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_bwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_bwd(src_row, dst_row2, cols); + else + return; +} + +template +__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, + int *mask1, int *mask2, int *dest1, + int *dest2, const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_fwd_selector( + batch_tokens + (row * h), expert_input + (dest1[row] * h), + expert_input + (dest2[row] * h), h, mask1[row], indicator2); +} + +template +__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_bwd_selector( + tokens_grad + (row * h), expert_grad + (dest1[row] * h), + expert_grad + (dest2[row] * h), h, mask1[row], indicator2); +} + +// COMBINE KERNELS -------------------------------- + +template +__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_fwd(src_row1, src_row2, dst_row, + weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_fwd(src_row1, dst_row, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_fwd(src_row2, dst_row, weight2, cols); + else + return; +} + +template +__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, T *tks_row1, T *tks_row2, + T *wt_grad1, T *wt_grad2, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_bwd(src_row1, src_row2, dst_row, + tks_row1, tks_row2, wt_grad1, + wt_grad2, weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_bwd(src_row1, dst_row, tks_row1, + wt_grad1, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_bwd(src_row2, dst_row, tks_row2, + wt_grad2, weight2, cols); + else + return; +} + +template +__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, + T *logits, int *mask1, int *mask2, int *dest1, + int *dest2, const int e, const int c, + const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e); + moe_cb_fwd_selector( + expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), + combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], + indicator2); +} + +template +__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, + T *logits, T *logits_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int e, const int c, const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); + moe_cb_bwd_selector( + expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), + tokens_grad + (row * h), h, tks + (dest1[row] * h), + tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], + row_log[eid2], mask1[row], indicator2); +} + +// CUMSUM KERNEL -------------------------------- + +template +__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, + const int e) { + assert(s % pack_size == 0); + constexpr int bpack_size = block_size * pack_size; + int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; + __shared__ int temp[block_size + 1]; + int pack[pack_size]; + + for (int idx = 0; idx < s; idx += bpack_size) { + int offset = 1; + + if (idx + tps < s) { + temp[tid] = inputs[tps * e + bid]; +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + pack[i] = inputs[(tps + i) * e + bid]; + } +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + temp[tid] += pack[i]; + } + } + + for (int i = block_size >> 1; i > 0; i >>= 1) { + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1; + temp[j + offset] += temp[j]; + } + offset <<= 1; + } + + if (tid == 0) { + temp[block_size] = temp[block_size - 1]; + temp[block_size - 1] = 0; + } + + for (int i = 1; i < block_size; i <<= 1) { + offset >>= 1; + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; + temp[j] = temp[k]; + temp[k] += ts; + } + } + __syncthreads(); + + if (tid == 0) temp[0] = temp[block_size]; + __syncthreads(); + + if (idx + tps < s) { + temp[tid + 1] += last_sum; +#pragma unroll + for (int i = pack_size - 1; i > 0; --i) { + outputs[(tps + i) * e + bid] = temp[tid + 1]; + temp[tid + 1] -= pack[i]; + } + outputs[tps * e + bid] = temp[tid + 1]; + } + __syncthreads(); + + last_sum += temp[0]; + inputs += bpack_size * e; + outputs += bpack_size * e; + } +} + +// LAUNCH FUNCTIONS -------------------------------- + +template +void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, + int *mask2, int *dest1, int *dest2, const int s, + const int h) { + if (h < 256) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); +} + +template +void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, + int *dest1, int *dest2, const int s, const int h) { + if (h < 256) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); +} + +template +void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, + int *mask1, int *mask2, int *dest1, int *dest2, + const int s, const int e, const int c, const int h) { + if (h < 256) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 512) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 1024) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 2048) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, + dest2, e, c, h); +} + +template +void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, + T *logits_grad, int *mask1, int *mask2, int *dest1, + int *dest2, const int s, const int e, const int c, + const int h) { + if (h < 256) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + else // if (h < 512) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + // else if (h < 1024) + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); + // else + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); +} + +void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { + if (s <= 256) + cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); + else if (s <= 512) + cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); + else if (s <= 1024) + cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); + else if (s <= 2048) + cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); + else + cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); +} + +// API FUNCTIONS -------------------------------- + +#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented yet for specific data type."); \ + } + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {ec, h}, + torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + batch_tokens.scalar_type(), "moe dispatch forward", + moe_dpch_fwd_launch( + batch_tokens.data(), res.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_grad.scalar_type(), "moe dispatch backward", + moe_dpch_bwd_launch( + res.data(), expert_grad.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(expert_tokens.dtype() == logits.dtype()); + + auto res = torch::zeros( + {s, h}, + torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_tokens.scalar_type(), "moe combine forward", + moe_cb_fwd_launch( + expert_tokens.data(), res.data(), + logits.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return res; +} + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(tokens_grad.dtype() == expert_tokens.dtype()); + assert(expert_tokens.dtype() == logits.dtype()); + + auto egrad = torch::zeros( + {e * c, h}, + torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), + wgrad = torch::zeros( + {s, e}, torch::dtype(logits.dtype()).device(logits.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + tokens_grad.scalar_type(), "moe combine backward", + moe_cb_bwd_launch( + tokens_grad.data(), egrad.data(), + expert_tokens.data(), logits.data(), + wgrad.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return {egrad, wgrad}; +} + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { + assert(mask.dim() == 2); + assert(mask.dtype() == torch::kInt32); + + const int s = mask.size(0), e = mask.size(1); + auto res = + torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); + cumsum_launch(mask.data(), res.data(), s, e); + + return res; +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu index 49ab83e8fc81..85f935152f8a 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu @@ -379,4 +379,4 @@ void multi_tensor_norm_out_cuda( norm_type, alpha, beta); return; -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu index 54c4220190d8..63771cf40bcb 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu @@ -351,4 +351,4 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, lr, weight_decay, use_nvlamb);) AT_CUDA_CHECK(cudaGetLastError()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu index 360485dcd02f..2f58a0f16dce 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu @@ -122,4 +122,4 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, AT_CUDA_CHECK(cudaGetLastError()); // AT_CUDA_CHECK(cudaDeviceSynchronize()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu index 35f2c9b4ed15..7f48dbd5d497 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu @@ -164,4 +164,4 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, } AT_CUDA_CHECK(cudaGetLastError()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp index 4ae3c853ca5e..8c2982b0cff9 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp @@ -3,82 +3,68 @@ #include #include + #include namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -int get_batch_per_block_cuda( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads); - -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, + int attn_heads); + +torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor) { AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); return fwd_cuda(input, mask, scale_factor); } -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - +torch::Tensor bwd(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, float scale_factor) { AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); } -int get_batch_per_block( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, + attn_heads); } -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", - &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); + m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); m.def("get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size." - ); + &multihead_attn::fused_softmax::scaled_masked_softmax:: + get_batch_per_block, + "Return Batch per block size."); } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h index 1583030b8235..d3e6f04e6093 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h @@ -4,12 +4,12 @@ #pragma once #include +#include #include +#include + #include #include -#include -#include -#include namespace { @@ -17,37 +17,53 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); +} int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; @@ -55,438 +71,468 @@ struct Max { }; template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } /* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Explicit masking + */ +template __global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } + output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale, + int micro_batch_size, int element_count, int pad_batches) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = + (blockDim.y * + (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * + WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i * element_count + it * WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; } + copy_vector( + dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } } + } } -template +template __global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = + first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; } + } } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); } + copy_vector( + gradInput + i * element_count + it * WARP_SIZE, out); + } } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; } -} // end of anonymous namespace -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ +template +void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads, + int pad_batches) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { int log2_elements = log2_ceil(key_seq_len); const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + // use 128 threads per block to maximimize gpu utilization constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } + TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; } + } } -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) -{ - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } +template +void dispatch_scaled_masked_softmax_backward(output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; } + } } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h index 3af487f9de0f..54c8e9133a1b 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h @@ -4,11 +4,12 @@ #pragma once #include +#include #include +#include + #include #include -#include -#include namespace { @@ -16,53 +17,78 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float2 *)dst) = *((float2 *)src); +} + template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); +} template __device__ __inline__ void copy_zero_vector(Datatype *dst); template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } +__device__ __inline__ void copy_zero_vector( + c10::BFloat16 *dst) { + *dst = 0.0; +} template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } +__device__ __inline__ void copy_zero_vector( + c10::BFloat16 *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); +} template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { + *dst = 0.0; +} template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); +} int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; @@ -70,431 +96,505 @@ struct Max { }; template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } /* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Implicit time (diagonal masking) + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Implicit time (diagonal masking) */ -template +template __global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size, + int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = + (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector( + temp_data, src + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } } + copy_vector( + dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector( + dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } } + } } -template +template __global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; + } } + } } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); } + copy_vector( + gradInput + i * element_count * stride + it * WARP_SIZE, out); + } } + } } -} // end of anonymous namespace +} // end of anonymous namespace -template +template void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } + output_t *dst, const input_t *src, const input_t scale, + int softmax_elements, int softmax_elements_stride, int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + default: + break; } + } } -template +template void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } + output_t *grad_input, input_t *grad, const input_t *output, + const acc_t scale, int softmax_elements, int softmax_elements_stride, + int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + default: + break; } + } } diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py index 40355a41ed0d..c7d2a3a45022 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -18,7 +18,6 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, input, weight, bias, normalized_shape, eps): @@ -30,7 +29,6 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): global layer_norm if layer_norm is None: - layer_norm = LayerNormBuilder().load() output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.layernorm_op = layer_norm @@ -43,17 +41,14 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): def backward(ctx, grad_output): input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias \ - = layer_norm.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) + grad_input, grad_weight, grad_bias = layer_norm.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) return grad_input, grad_weight, grad_bias, None, None class MixedFusedLayerNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): super(MixedFusedLayerNorm, self).__init__() @@ -66,13 +61,11 @@ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None): self.reset_parameters() def reset_parameters(self): - init.ones_(self.weight) init.zeros_(self.bias) def forward(self, input): - return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps) def __repr__(self): - return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})' + return f"MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})" diff --git a/colossalai/kernel/cuda_native/mha/__init__.py b/colossalai/kernel/cuda_native/mha/__init__.py index 21fddd512957..cad36e598d14 100644 --- a/colossalai/kernel/cuda_native/mha/__init__.py +++ b/colossalai/kernel/cuda_native/mha/__init__.py @@ -1,3 +1,3 @@ from .mha import ColoAttention -__all__ = ['ColoAttention'] +__all__ = ["ColoAttention"] diff --git a/colossalai/kernel/cuda_native/mha/flash_attn_2.py b/colossalai/kernel/cuda_native/mha/flash_attn_2.py index 6a8d74f70c1d..9ee83915b1b4 100644 --- a/colossalai/kernel/cuda_native/mha/flash_attn_2.py +++ b/colossalai/kernel/cuda_native/mha/flash_attn_2.py @@ -8,7 +8,7 @@ def is_ampere_or_better_gpu(): if torch.cuda.is_available(): device = torch.device("cuda") properties = torch.cuda.get_device_properties(device) - if properties.major >= 8: # Ampere GPUs or newer + if properties.major >= 8: # Ampere GPUs or newer return True return False @@ -18,30 +18,33 @@ def is_ampere_or_better_gpu(): if is_ampere_or_better_gpu(): HAS_FLASH_ATTN = True else: - warnings.warn('FlashAttention only supports Ampere GPUs or newer.') + warnings.warn("FlashAttention only supports Ampere GPUs or newer.") HAS_FLASH_ATTN = False try: from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + HAS_FLASH_ATTN = True except ImportError: - warnings.warn('please install flash_attn from https://github.com/HazyResearch/flash-attention') + warnings.warn("please install flash_attn from https://github.com/HazyResearch/flash-attention") HAS_FLASH_ATTN = False if HAS_FLASH_ATTN: - from einops import rearrange + pass from .utils import SeqLenInfo - def flash_attention(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0., - scale: float = None, - causal: bool = False, - padded: bool = False): + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: SeqLenInfo, + seq_len_info_kv: SeqLenInfo, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): """ Arguments: q: (batch, q_seqlen, nheads, headdim) @@ -60,9 +63,18 @@ def flash_attention(q: torch.Tensor, if seq_len_info_kv == None: seq_len_info_kv = seq_len_info_q - attn_out = flash_attn_varlen_func(q, k, v, seq_len_info_q.cu_seqlens, seq_len_info_kv.cu_seqlens, - seq_len_info_q.max_seqlen, seq_len_info_kv.max_seqlen, dropout_p, scale, - causal) + attn_out = flash_attn_varlen_func( + q, + k, + v, + seq_len_info_q.cu_seqlens, + seq_len_info_kv.cu_seqlens, + seq_len_info_q.max_seqlen, + seq_len_info_kv.max_seqlen, + dropout_p, + scale, + causal, + ) else: attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) return attn_out diff --git a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py index 8a898080877c..649e74d61bab 100644 --- a/colossalai/kernel/cuda_native/mha/mem_eff_attn.py +++ b/colossalai/kernel/cuda_native/mha/mem_eff_attn.py @@ -9,9 +9,10 @@ LowerTriangularMask, LowerTriangularMaskWithTensorBias, ) + HAS_MEM_EFF_ATTN = True except ImportError: - warnings.warn('please install xformers from https://github.com/facebookresearch/xformers') + warnings.warn("please install xformers from https://github.com/facebookresearch/xformers") HAS_MEM_EFF_ATTN = False if HAS_MEM_EFF_ATTN: @@ -29,30 +30,30 @@ for op in MemoryEfficientAttentionCutlassOp: allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - def mem_eff_attention(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0., - scale: float = None, - causal: bool = False, - padded: bool = False): - + def mem_eff_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: SeqLenInfo, + seq_len_info_kv: SeqLenInfo, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): attn_bias = None - if padded: # bert style + if padded: # bert style if not causal: attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) else: attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - elif causal: # gpt style + elif causal: # gpt style attn_bias = LowerTriangularMask() - if bias is not None: # alibi / relative position embedding + if bias is not None: # alibi / relative position embedding assert allow_alibi, "flash attention with bias is not supported in this system." - assert causal, \ - "attention with bias is only supported for causal attention so far." + assert causal, "attention with bias is only supported for causal attention so far." attn_bias = attn_bias.add_bias(bias) if padded: diff --git a/colossalai/kernel/cuda_native/mha/mha.py b/colossalai/kernel/cuda_native/mha/mha.py index 8f449a138c51..1c778439d33f 100644 --- a/colossalai/kernel/cuda_native/mha/mha.py +++ b/colossalai/kernel/cuda_native/mha/mha.py @@ -2,7 +2,6 @@ from typing import Optional import torch -import torch.nn.functional as F from einops import rearrange from ..scaled_softmax import AttnMaskType @@ -17,11 +16,11 @@ class ColoAttention(torch.nn.Module): - def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None): super().__init__() - assert embed_dim % num_heads == 0, \ - f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." + assert ( + embed_dim % num_heads == 0 + ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})." if scale is not None: self.scale = scale else: @@ -39,14 +38,15 @@ def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor: return Repad.apply(tensor, indices, batch_size, seq_len) - def forward(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - attn_mask_type: Optional[AttnMaskType] = None, - bias: Optional[torch.Tensor] = None): - + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + attn_mask_type: Optional[AttnMaskType] = None, + bias: Optional[torch.Tensor] = None, + ): attn = None if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None: attn = flash_attention @@ -62,18 +62,21 @@ def forward(self, seq_len_info_kv = None if padded: # bert style, unpad process - assert attn_mask is not None, \ - f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." - assert attn_mask.dim() == 2, \ - "attention mask is supposed to have shape (batch_size, seq_len), " + \ - f"but got {attn_mask.dim()} dimensions." + assert ( + attn_mask is not None + ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}." + assert attn_mask.dim() == 2, ( + "attention mask is supposed to have shape (batch_size, seq_len), " + + f"but got {attn_mask.dim()} dimensions." + ) # bert style if tgt_len == src_len: seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) if batch_size > 1: - query, key, value = self.unpad(torch.stack([query, key, value], dim=2), - seq_len_info_q.indices).unbind(dim=1) + query, key, value = self.unpad( + torch.stack([query, key, value], dim=2), seq_len_info_q.indices + ).unbind(dim=1) else: query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) seq_len_info_kv = seq_len_info_q @@ -82,26 +85,29 @@ def forward(self, seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device) if batch_size > 1: query = rearrange(query, "b s ... -> c (b s) ...", c=1) - key, value = self.unpad(torch.stack([query, key, value], dim=2), - seq_len_info_kv.indices).unbind(dim=1) + key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind( + dim=1 + ) else: query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1) - out = attn(query, - key, - value, - seq_len_info_q, - seq_len_info_kv, - dropout_p=self.dropout, - scale=self.scale, - causal=causal, - padded=padded) + out = attn( + query, + key, + value, + seq_len_info_q, + seq_len_info_kv, + dropout_p=self.dropout, + scale=self.scale, + causal=causal, + padded=padded, + ) # repad if padded: if batch_size > 1: out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len) - out = rearrange(out, '(b s) h d -> b s h d', b=batch_size) + out = rearrange(out, "(b s) h d -> b s h d", b=batch_size) - out = rearrange(out, 'b s h d -> b s (h d)') + out = rearrange(out, "b s h d -> b s (h d)") return out diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py index e3e431fa7e99..fe31921b961b 100644 --- a/colossalai/kernel/cuda_native/mha/utils.py +++ b/colossalai/kernel/cuda_native/mha/utils.py @@ -20,18 +20,18 @@ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): # [b, s, ...] assert tensor.ndim >= 3 ctx.bsz = tensor.shape[0] - out = rearrange(tensor, 'b s ... -> (b s) ...') + out = rearrange(tensor, "b s ... -> (b s) ...") ctx.shape = out.shape # [ntokens, ...] return out[indices] @staticmethod def backward(ctx, grad_output): - indices, = ctx.saved_tensors + (indices,) = ctx.saved_tensors # [ntokens, ...] grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) grad[indices] = grad_output - grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz) + grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) # [b, s, ...] return grad, None @@ -54,7 +54,7 @@ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, s @staticmethod def backward(ctx, grad_output): - indices, = ctx.saved_tensors + (indices,) = ctx.saved_tensors # [b*s, ...] grad = grad_output[indices] # [ntokens, ...] diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py index 69246f2f3854..87afc1862847 100644 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ b/colossalai/kernel/cuda_native/multihead_attention.py @@ -36,34 +36,64 @@ def calc_offset(sizes): @dataclass class Config: - max_batch_tokens: int # max batch token numbers - max_seq_len: int # max sequence length - hidden_size: int # size of transformer hidden layers - nhead: int # number of heads in attention - attn_prob_dropout_ratio: float # attention score dropout ratio - hidden_dropout_ratio: float # dropout ration before residual - norm_first: bool # norm_first - fp16: bool # fp16 precision + max_batch_tokens: int # max batch token numbers + max_seq_len: int # max sequence length + hidden_size: int # size of transformer hidden layers + nhead: int # number of heads in attention + attn_prob_dropout_ratio: float # attention score dropout ratio + hidden_dropout_ratio: float # dropout ration before residual + norm_first: bool # norm_first + fp16: bool # fp16 precision class MultiHeadAttention1DFunc(Function): - @staticmethod - def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, - norm_bias, config): + def forward( + ctx, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + config, + ): cuda_module = colossal_multihead_attention - forward_func = (cuda_module.multihead_attention_fw_fp16 - if config.fp16 else cuda_module.multihead_attention_fw_fp32) + forward_func = ( + cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32 + ) if config.fp16: input = input.to(torch.half) input_mask = input_mask.to(torch.half) - (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, - out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first) + (output,) = forward_func( + config.layer_id, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + config.training, + config.norm_first, + ) if config.is_grad_enabled and config.training: - ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, - out_proj_bias, norm_weight, norm_bias) + ctx.save_for_backward( + output, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + ) ctx.config = config return output @@ -72,11 +102,21 @@ def backward(ctx, grad_output): assert ctx.config.training cuda_module = colossal_multihead_attention - backward_func = (cuda_module.multihead_attention_bw_fp16 - if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32) + backward_func = ( + cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32 + ) - output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \ - out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors + ( + output, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + ) = ctx.saved_tensors grad_input = None grad_in_proj_weight = None @@ -91,13 +131,39 @@ def backward(ctx, grad_output): output = output.to(torch.half) input = input.to(torch.half) input_mask = input_mask.to(torch.half) - grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \ - grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func( - ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, - in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias) + ( + grad_input, + grad_in_proj_weight, + grad_in_proj_bias, + grad_out_proj_weight, + grad_out_proj_bias, + grad_norm_weight, + grad_norm_bias, + ) = backward_func( + ctx.config.layer_id, + grad_output, + output, + input, + input_mask, + in_proj_weight, + in_proj_bias, + out_proj_weight, + out_proj_bias, + norm_weight, + norm_bias, + ) - return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, - grad_norm_weight, grad_norm_bias, None) + return ( + grad_input, + None, + grad_in_proj_weight, + grad_in_proj_bias, + grad_out_proj_weight, + grad_out_proj_bias, + grad_norm_weight, + grad_norm_bias, + None, + ) class MultiHeadAttention(nn.Module): @@ -122,8 +188,9 @@ class MultiHeadAttention(nn.Module): def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): super(MultiHeadAttention, self).__init__() - self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, - fp16) + self.config = Config( + batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16 + ) check_config(self.config) self.pg = pg self.pg_size = 1 @@ -136,13 +203,17 @@ def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, nor global colossal_multihead_attention if colossal_multihead_attention is None: from colossalai.kernel.op_builder import MultiHeadAttnBuilder + multihead_attention = MultiHeadAttnBuilder().load() colossal_multihead_attention = multihead_attention # create the layer in cuda kernels. cuda_module = colossal_multihead_attention - create_layer_func = (cuda_module.create_multihead_attention_fp16 - if self.config.fp16 else cuda_module.create_multihead_attention_fp32) + create_layer_func = ( + cuda_module.create_multihead_attention_fp16 + if self.config.fp16 + else cuda_module.create_multihead_attention_fp32 + ) create_layer_func( self.config.layer_id, @@ -204,13 +275,15 @@ def reset_parameters(self): with torch.no_grad(): self.in_proj_weight.copy_( - attn_qkvw_global.view(3, hs, hs)[:, - int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / - self.pg_size), :]) + attn_qkvw_global.view(3, hs, hs)[ + :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), : + ] + ) self.in_proj_bias.copy_( - attn_qkvb_global.view(3, hs)[:, - int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) / - self.pg_size)]) + attn_qkvb_global.view(3, hs)[ + :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size) + ] + ) attn_ow_global = torch.empty(hs, hs) nn.init.xavier_uniform_(attn_ow_global, 1.0) @@ -218,9 +291,9 @@ def reset_parameters(self): torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) attn_ow_global = attn_ow_global.cpu() with torch.no_grad(): - self.out_proj_weight.copy_(attn_ow_global[:, - int(hs * rank_in_pg / - self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)]) + self.out_proj_weight.copy_( + attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)] + ) else: attn_qkvw = self.in_proj_weight.view(-1, hs) @@ -238,7 +311,7 @@ def forward(self, hidden_states, encoder_padding_mask): self.config.training = self.training self.config.is_grad_enabled = torch.is_grad_enabled() hidden_states = hidden_states.contiguous() - encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()) + encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous() bs, sl, dim = hidden_states.size() if bs * sl > self.config.max_batch_tokens: @@ -250,8 +323,16 @@ def forward(self, hidden_states, encoder_padding_mask): else: assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) - output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight, - self.in_proj_bias, self.out_proj_weight, self.out_proj_bias, - self.norm_weight, self.norm_bias, self.config) + output = MultiHeadAttention1DFunc.apply( + hidden_states, + encoder_padding_mask, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.norm_weight, + self.norm_bias, + self.config, + ) return output.to(self.precision) diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py index 41cd4b20faa1..26a5bce16d5c 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -108,15 +108,16 @@ def __init__( super(FusedScaleMaskSoftmax, self).__init__() self.input_in_fp16 = input_in_fp16 self.input_in_bf16 = input_in_bf16 - assert not (self.input_in_fp16 - and self.input_in_bf16), "both fp16 and bf16 flags cannot be active at the same time." + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 self.attn_mask_type = attn_mask_type self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale - assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled" + assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" def forward(self, input, mask): # [b, np, sq, sk] @@ -130,13 +131,14 @@ def forward(self, input, mask): def is_kernel_available(self, mask, b, np, sq, sk): attn_batches = b * np - if (self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and mask is not None # mask tensor must not be None - and 16 < sk <= 2048 # sk must be 16 ~ 2048 - and sq % 4 == 0 # sq must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): if 0 <= sk <= 2048: batch_per_block = self.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/kernel/jit/__init__.py b/colossalai/kernel/jit/__init__.py index 57b8fb7b2e99..67a147cd581c 100644 --- a/colossalai/kernel/jit/__init__.py +++ b/colossalai/kernel/jit/__init__.py @@ -1,8 +1,10 @@ -from .option import set_jit_fusion_options -from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference +from .bias_dropout_add import bias_dropout_add_fused_inference, bias_dropout_add_fused_train from .bias_gelu import bias_gelu_impl +from .option import set_jit_fusion_options __all__ = [ - "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", - "set_jit_fusion_options" + "bias_dropout_add_fused_train", + "bias_dropout_add_fused_inference", + "bias_gelu_impl", + "set_jit_fusion_options", ] diff --git a/colossalai/kernel/jit/bias_dropout_add.py b/colossalai/kernel/jit/bias_dropout_add.py index 32965c1ebd69..e046ee2964af 100644 --- a/colossalai/kernel/jit/bias_dropout_add.py +++ b/colossalai/kernel/jit/bias_dropout_add.py @@ -1,5 +1,4 @@ import torch -from torch import Tensor def bias_dropout_add(x, bias, residual, prob, training): @@ -10,16 +9,14 @@ def bias_dropout_add(x, bias, residual, prob, training): @torch.jit.script -def bias_dropout_add_fused_train(x: torch.Tensor, - bias: torch.Tensor, - residual: torch.Tensor, - prob: float) -> torch.Tensor: +def bias_dropout_add_fused_train( + x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float +) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script -def bias_dropout_add_fused_inference(x: torch.Tensor, - bias: torch.Tensor, - residual: torch.Tensor, - prob: float) -> torch.Tensor: +def bias_dropout_add_fused_inference( + x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float +) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) diff --git a/colossalai/kernel/jit/bias_gelu.py b/colossalai/kernel/jit/bias_gelu.py index 33b4ac32b044..5fa0d07015be 100644 --- a/colossalai/kernel/jit/bias_gelu.py +++ b/colossalai/kernel/jit/bias_gelu.py @@ -29,7 +29,6 @@ def bias_gelu_back(g, bias, y): class GeLUFunction(torch.autograd.Function): - @staticmethod # bias is an optional argument def forward(ctx, input, bias): diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py index 8eb4e0c880a0..8bebad894ca4 100644 --- a/colossalai/kernel/jit/option.py +++ b/colossalai/kernel/jit/option.py @@ -10,15 +10,14 @@ def set_jit_fusion_options(): - """Set PyTorch JIT layer fusion options. - """ + """Set PyTorch JIT layer fusion options.""" # LSG: the latest pytorch and CUDA versions may not support # the following jit settings global JIT_OPTIONS_SET if JIT_OPTIONS_SET == False: # flags required to enable jit fusion kernels - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): # nvfuser torch._C._jit_set_profiling_executor(True) @@ -38,12 +37,14 @@ def set_jit_fusion_options(): JIT_OPTIONS_SET = True -def warmup_jit_fusion(batch_size: int, - hidden_size: int, - seq_length: int = 512, - vocab_size: int = 32768, - dtype: torch.dtype = torch.float32): - """ Compile JIT functions before the main training steps """ +def warmup_jit_fusion( + batch_size: int, + hidden_size: int, + seq_length: int = 512, + vocab_size: int = 32768, + dtype: torch.dtype = torch.float32, +): + """Compile JIT functions before the main training steps""" embed = Embedding(vocab_size, hidden_size).to(get_current_device()) linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device()) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 75812db036a9..bc68a07e6fba 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -1,5 +1,6 @@ try: import triton + HAS_TRITON = True from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd @@ -11,8 +12,14 @@ from .token_attention_kernel import token_attention_fwd __all__ = [ - "llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward", - "copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd" + "llama_context_attn_fwd", + "bloom_context_attn_fwd", + "softmax", + "layer_norm", + "rmsnorm_forward", + "copy_kv_cache_to_dest", + "rotary_embedding_fwd", + "token_attention_fwd", ] except ImportError: diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 38db2048c6a4..dac95bfb14ae 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -1,8 +1,11 @@ -import torch import math + +import torch + try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -10,28 +13,42 @@ if HAS_TRITON: - ''' - this function is modified from - https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 - ''' + """ + this function is modified from + https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 + """ + @triton.jit def _context_flash_attention_kernel( - Q, K, V, sm_scale, - B_Start_Loc, B_Seqlen, - TMP, + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, alibi_ptr, Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - stride_tmp_b, stride_tmp_h, stride_tmp_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, # suggtest set-up 64, 128, 256, 512 - BLOCK_M: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): - batch_id = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) @@ -40,13 +57,18 @@ def _context_flash_attention_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - - # get batch info + + # get batch info cur_batch_seq_len = tl.load(B_Seqlen + batch_id) cur_batch_start_index = tl.load(B_Start_Loc + batch_id) block_start_loc = BLOCK_M * start_m - - load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + + load_p_ptrs = ( + Q + + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd @@ -56,7 +78,7 @@ def _context_flash_attention_kernel( m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - + if alibi_ptr is not None: alibi_m = tl.load(alibi_ptr + cur_head) @@ -64,8 +86,11 @@ def _context_flash_attention_kernel( for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + k = tl.load( + k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -95,21 +120,25 @@ def _context_flash_attention_kernel( acc_scale = tl.load(t_ptrs) acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + v = tl.load( + v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i l_i = l_i_new m_i = m_i_new - - off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + + off_o = ( + (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return - - + @torch.no_grad() def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): BLOCK = 128 @@ -129,17 +158,31 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) _context_flash_attention_kernel[grid]( - q, k, v, sm_scale, - b_start_loc, b_seq_len, + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, tmp, alibi, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - tmp.stride(0), tmp.stride(1), tmp.stride(2), - # manually setting this blcok num, we can use tuning config to futher speed-up + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + # manually setting this blcok num, we can use tuning config to futher speed-up BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -147,7 +190,7 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al num_stages=1, ) return - + @torch.no_grad() def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): BLOCK = 128 @@ -166,19 +209,34 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): num_warps = 4 if Lk <= 64 else 8 # num_warps = 4 _context_flash_attention_kernel[grid]( - q, k, v, sm_scale, b_start_loc, b_seq_len, + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, tmp, None, o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - tmp.stride(0), tmp.stride(1), tmp.stride(2), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, ) - return \ No newline at end of file + return diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index c1eaa8a10ed1..02edcc9a903a 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -3,25 +3,28 @@ try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") if HAS_TRITON: + @triton.jit def _fwd_copy_kv_cache_dest( - kv_cache_ptr, dest_index_ptr, + kv_cache_ptr, + dest_index_ptr, out, - stride_k_bs, - stride_k_h, + stride_k_bs, + stride_k_h, stride_k_d, - stride_o_bs, - stride_o_h, + stride_o_bs, + stride_o_h, stride_o_d, head_num, BLOCK_DMODEL: tl.constexpr, - BLOCK_HEAD: tl.constexpr + BLOCK_HEAD: tl.constexpr, ): cur_index = tl.program_id(0) offs_h = tl.arange(0, BLOCK_HEAD) @@ -31,15 +34,14 @@ def _fwd_copy_kv_cache_dest( cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets - + o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] o_ptrs = out + dest_index * stride_o_bs + o_offsets k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0) tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) return - - + @torch.no_grad() def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): seq_len = dest_index_ptr.shape[0] @@ -47,16 +49,18 @@ def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): head_dim = k_ptr.shape[2] assert head_num == out.shape[1], "head_num should be the same for k_ptr and out" assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" - + num_warps = 2 _fwd_copy_kv_cache_dest[(seq_len,)]( - k_ptr, dest_index_ptr, out, - k_ptr.stride(0), - k_ptr.stride(1), + k_ptr, + dest_index_ptr, + out, + k_ptr.stride(0), + k_ptr.stride(1), k_ptr.stride(2), - out.stride(0), - out.stride(1), + out.stride(0), + out.stride(1), out.stride(2), head_num, BLOCK_DMODEL=head_dim, @@ -65,5 +69,3 @@ def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): num_stages=2, ) return - - diff --git a/colossalai/kernel/triton/fused_layernorm.py b/colossalai/kernel/triton/fused_layernorm.py index 99800acfbb92..24083b050808 100644 --- a/colossalai/kernel/triton/fused_layernorm.py +++ b/colossalai/kernel/triton/fused_layernorm.py @@ -3,6 +3,7 @@ try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -14,13 +15,13 @@ @triton.jit def _layer_norm_fwd_fused( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero BLOCK_SIZE: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. @@ -32,15 +33,15 @@ def _layer_norm_fwd_fused( _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) _mean += a mean = tl.sum(_mean, axis=0) / N # Compute variance _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) - x = tl.where(cols < N, x - mean, 0.) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.0) _var += x * x var = tl.sum(_var, axis=0) / N rstd = 1 / tl.sqrt(var + eps) @@ -50,7 +51,7 @@ def _layer_norm_fwd_fused( mask = cols < N w = tl.load(W + cols, mask=mask) b = tl.load(B + cols, mask=mask) - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) x_hat = (x - mean) * rstd y = x_hat * w + b # Write output @@ -71,13 +72,7 @@ def layer_norm(x, weight, bias, eps): # heuristics for number of warps num_warps = min(max(BLOCK_SIZE // 256, 1), 8) # enqueue kernel - _layer_norm_fwd_fused[(M,)](x_arg, - y, - weight, - bias, - x_arg.stride(0), - N, - eps, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps) + _layer_norm_fwd_fused[(M,)]( + x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps + ) return y diff --git a/colossalai/kernel/triton/qkv_matmul_kernel.py b/colossalai/kernel/triton/qkv_matmul_kernel.py index 62fc6bba0360..7b5cd2923f0e 100644 --- a/colossalai/kernel/triton/qkv_matmul_kernel.py +++ b/colossalai/kernel/triton/qkv_matmul_kernel.py @@ -1,7 +1,7 @@ -import torch try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -9,9 +9,10 @@ if HAS_TRITON: - ''' + """ this kernel function is modified from https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html - ''' + """ + @triton.jit def qkv_gemm_4d_kernel( a_ptr, @@ -34,12 +35,12 @@ def qkv_gemm_4d_kernel( stride_cn, scale, # Meta-parameters - BLOCK_SIZE_M : tl.constexpr = 64, - BLOCK_SIZE_N : tl.constexpr = 32, - BLOCK_SIZE_K : tl.constexpr = 32, - GROUP_SIZE_M : tl.constexpr = 8, + BLOCK_SIZE_M: tl.constexpr = 64, + BLOCK_SIZE_N: tl.constexpr = 32, + BLOCK_SIZE_K: tl.constexpr = 32, + GROUP_SIZE_M: tl.constexpr = 8, ): - r""" A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, + r"""A kernel function which is used to do batch-matmul for Q*K^T or score_matrix * V for attention layer, where score_matrix is softmax(Q*V^T/sqrt(hidden_size)) Args: a_ptr(torch.Tensor): pointer to input tensor array (bs, M, h, K) or (bs, h, M, K) @@ -53,21 +54,21 @@ def qkv_gemm_4d_kernel( stride_bh(tl.constexpr): stride for h-dimention for tensor array B stride_bk(tl.constexpr): stride for k-dimention for tensor array B stride_bn(tl.constexpr): stride for n-dimention for tensor array B - stride_cb(tl.constexpr): stride for bs-dimention for tensor array output + stride_cb(tl.constexpr): stride for bs-dimention for tensor array output stride_ch(tl.constexpr): stride for h-dimention for tensor array output stride_cm(tl.constexpr): stride for m-dimention for tensor array output stride_cn(tl.constexpr): stride for n-dimention for tensor array output BLOCK_SIZE_M : tiling size for M-dimension of tensor Array a BLOCK_SIZE_N : tiling size for N-dimension of tensor Array b BLOCK_SIZE_K : tiling size for K-dimension of a and b - GROUP_SIZE_M : group size for reducing cache miss, more details: + GROUP_SIZE_M : group size for reducing cache miss, more details: """ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - batch = tl.program_id(axis = 0) - head = tl.program_id(axis = 1) - pid = tl.program_id(axis = 2) + batch = tl.program_id(axis=0) + head = tl.program_id(axis=1) + pid = tl.program_id(axis=2) # the following is from tutorial: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -77,33 +78,38 @@ def qkv_gemm_4d_kernel( pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + - (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) - b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + - (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + a_ptrs = ( + a_ptr + batch * stride_ab + head * stride_ah + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + ) + b_ptrs = ( + b_ptr + batch * stride_bb + head * stride_bh + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, K, BLOCK_SIZE_K): a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) - a = tl.load(a_ptrs, mask=a_mask, other=0.) - b = tl.load(b_ptrs, mask=b_mask, other=0.) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - + accumulator = accumulator.to(c_ptr.dtype.element_ty) if scale > 0: accumulator = accumulator * scale.to(c_ptr.dtype.element_ty) - offs_accumu_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_accumu_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_accumu_m[:, None] + - stride_cn * offs_accumu_n[None, :]) + c_ptrs = ( + c_ptr + + batch * stride_cb + + head * stride_ch + + stride_cm * offs_accumu_m[:, None] + + stride_cn * offs_accumu_n[None, :] + ) accumulator_mask = (offs_accumu_m[:, None] < M) & (offs_accumu_n[None, :] < N) tl.store(c_ptrs, accumulator, mask=accumulator_mask) diff --git a/colossalai/kernel/triton/rms_norm.py b/colossalai/kernel/triton/rms_norm.py index 1fb79115f8ce..d5d6f9d85df1 100644 --- a/colossalai/kernel/triton/rms_norm.py +++ b/colossalai/kernel/triton/rms_norm.py @@ -3,17 +3,19 @@ try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") - + if HAS_TRITON: - ''' - this kernel function is modified from - https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py - ''' + """ + this kernel function is modified from + https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py + """ + @triton.jit def _rms_norm_fwd_fused( X, # pointer to the input @@ -32,7 +34,7 @@ def _rms_norm_fwd_fused( _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) _var += x * x var = tl.sum(_var, axis=0) / N rstd = 1 / tl.sqrt(var + eps) @@ -41,13 +43,12 @@ def _rms_norm_fwd_fused( cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) x_hat = x * rstd y = x_hat * w # Write output tl.store(Y + cols, y.to(tl.float16), mask=mask) - def rmsnorm_forward(x, weight, eps): # allocate output y = torch.empty_like(x) @@ -66,7 +67,5 @@ def rmsnorm_forward(x, weight, eps): BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2 num_warps = 8 # enqueue kernel - _rms_norm_fwd_fused[(M,)](x_arg, y, weight, - x_arg.stride(0), N, eps, - BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) + _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) return y diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py index d9d1b2bcf026..eb43fab7935c 100644 --- a/colossalai/kernel/triton/rotary_embedding_kernel.py +++ b/colossalai/kernel/triton/rotary_embedding_kernel.py @@ -29,19 +29,29 @@ def _rotary_kernel( dim_range0 = tl.arange(0, HEAD_DIM // 2) dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) - off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ - None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride - off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ - None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride + off_q0 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range0[None, None, :] * q_d_stride + ) + off_q1 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range1[None, None, :] * q_d_stride + ) off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride - q0 = tl.load(q + off_q0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0) - q1 = tl.load(q + off_q1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0) + q0 = tl.load( + q + off_q0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + q1 = tl.load( + q + off_q1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) @@ -49,12 +59,16 @@ def _rotary_kernel( out0 = q0 * cos - q1 * sin out1 = q0 * sin + q1 * cos - tl.store(q + off_q0, - out0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) - tl.store(q + off_q1, - out1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) + tl.store( + q + off_q0, + out0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + tl.store( + q + off_q1, + out1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) return diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py index 6ae54dcb0b38..4b56c8afd67f 100644 --- a/colossalai/kernel/triton/self_attention_nofusion.py +++ b/colossalai/kernel/triton/self_attention_nofusion.py @@ -1,9 +1,8 @@ import torch -from torch import nn try: import triton - import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -13,9 +12,10 @@ from .qkv_matmul_kernel import qkv_gemm_4d_kernel from .softmax import softmax_kernel - def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - input_mask: torch.Tensor, scale: float): - r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels + def self_attention_forward_without_fusion( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float + ): + r"""A function to do QKV Attention calculation by calling GEMM and softmax triton kernels Args: q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) @@ -65,7 +65,7 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t score_output.stride(2), score_output.stride(3), scale=scale, - # currently manually setting, later on we can use auto-tune config to match best setting + # currently manually setting, later on we can use auto-tune config to match best setting BLOCK_SIZE_M=64, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32, @@ -79,7 +79,6 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t n_rows, n_cols = score_output.shape if n_rows <= 350000: - block_size = max(triton.next_power_of_2(n_cols), 2) num_warps = 4 if block_size >= 4096: @@ -142,15 +141,9 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t ) return output.view(batches, -1, d_model) - def self_attention_compute_using_triton(qkv, - input_mask, - layer_past, - alibi, - scale, - head_size, - triangular=False, - use_flash=False): - + def self_attention_compute_using_triton( + qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False + ): assert qkv.is_contiguous() assert alibi is None, "current triton self-attention does not support alibi" batches = qkv.shape[0] @@ -158,8 +151,8 @@ def self_attention_compute_using_triton(qkv, num_of_heads = d_model // head_size q = qkv[:, :, :d_model] - k = qkv[:, :, d_model:d_model * 2] - v = qkv[:, :, d_model * 2:] + k = qkv[:, :, d_model : d_model * 2] + v = qkv[:, :, d_model * 2 :] q = q.view(batches, -1, num_of_heads, head_size) k = k.view(batches, -1, num_of_heads, head_size) v = v.view(batches, -1, num_of_heads, head_size) diff --git a/colossalai/kernel/triton/softmax.py b/colossalai/kernel/triton/softmax.py index c65adaf40dda..8ffce80a3041 100644 --- a/colossalai/kernel/triton/softmax.py +++ b/colossalai/kernel/triton/softmax.py @@ -1,39 +1,42 @@ import torch + try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") if HAS_TRITON: - ''' - softmax kernel is modified based on + """ + softmax kernel is modified based on https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py - ''' + """ + @triton.jit def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr): - r""" the kernel function for implementing softmax operator + r"""the kernel function for implementing softmax operator Args: output_ptr: the output after finishing softmax operation, (N, hidden_dim) input_ptr: the tensor of input, shape should be (N, hidden_dim) n_cols(tl.constexpr): the number of cols of input - BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim + BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim """ row_idx = tl.program_id(0) row_start_ptr = input_ptr + row_idx * row_stride col_offsets = tl.arange(0, BLOCK_SIZE) input_ptrs = row_start_ptr + col_offsets - row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float("inf")).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) if mask_ptr is not None: - # load mask into SRAM + # load mask into SRAM mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) - # update + # update row_minus_max = row_minus_max + mask numerator = tl.exp(row_minus_max) @@ -43,17 +46,16 @@ def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SI output_ptrs = output_row_start_ptr + col_offsets # Write back output to DRAM tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) - - + def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: if mask is not None: assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" - assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" - + assert dim == -1 or dim == len(input.shape) - 1, "currently softmax layer only support last dimention" + hidden_dim = input.shape[-1] output = torch.empty_like(input) input = input.view(-1, hidden_dim) - if mask is not None: + if mask is not None: mask = mask.view(-1, hidden_dim) assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" @@ -67,30 +69,31 @@ def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Ten else: num_warps = 4 - if num_rows <= 350000: + if num_rows <= 350000: grid = (num_rows,) - softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) + softmax_kernel[grid]( + output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps + ) else: grid = lambda meta: () - grid = lambda meta: ( - triton.cdiv(num_rows, meta["BLOCK_M"]), - ) + grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_M"]),) - BLOCK_M = 32 if block_size >= 4096: - BLOCK_M = 4 + pass elif block_size >= 2048: - BLOCK_M = 8 + pass - softmax_kernel[grid](output_ptr = output, - input_ptr = input, - row_stride = input.stride(0), - n_rows = num_rows, - n_cols = num_cols, - mask_ptr = mask, - # currently manually setting up size - BLOCK_M = 32, - BLOCK_SIZE = block_size) + softmax_kernel[grid]( + output_ptr=output, + input_ptr=input, + row_stride=input.stride(0), + n_rows=num_rows, + n_cols=num_cols, + mask_ptr=mask, + # currently manually setting up size + BLOCK_M=32, + BLOCK_SIZE=block_size, + ) - return output \ No newline at end of file + return output diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index c6b25f4abcec..7d0f9708516a 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -1,12 +1,12 @@ # Adapted from ModelTC https://github.com/ModelTC/lightllm -import math import torch try: import triton import triton.language as tl + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -15,10 +15,28 @@ if HAS_TRITON: @triton.jit - def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, - attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, q_batch_stride, q_head_stride, - q_head_dim_stride, k_batch_stride, k_head_stride, k_head_dim_stride, attn_head_stride, - attn_batch_stride, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr): + def _token_attn_1_kernel( + Q, + K, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): current_batch = tl.program_id(0) current_head = tl.program_id(1) start_n = tl.program_id(2) @@ -40,9 +58,11 @@ def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_ca for start_mark in range(0, block_mask, 1): q = tl.load(Q + off_q + start_mark) offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0) + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k, 1) @@ -52,11 +72,29 @@ def _token_attn_1_kernel(Q, K, sm_scale, kv_cache_loc, kv_cache_start_loc, kv_ca return @triton.jit - def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, - max_kv_cache_len, attn_out, kv_cache_loc_b_stride, kv_cache_loc_s_stride, - q_batch_stride, q_head_stride, q_head_dim_stride, k_batch_stride, k_head_stride, - k_head_dim_stride, attn_head_stride, attn_batch_stride, HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr): + def _token_attn_1_alibi_kernel( + Q, + K, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): current_batch = tl.program_id(0) current_head = tl.program_id(1) start_n = tl.program_id(2) @@ -79,9 +117,11 @@ def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_sta alibi_m = tl.load(alibi + current_head) q = tl.load(Q + off_q + start_mark) offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load(kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0) + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) att_value = tl.sum(q[None, :] * k, 1) @@ -92,14 +132,9 @@ def _token_attn_1_alibi_kernel(Q, K, sm_scale, alibi, kv_cache_loc, kv_cache_sta return @torch.no_grad() - def token_attn_fwd_1(q, - k, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - alibi=None): + def token_attn_fwd_1( + q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None + ): BLOCK = 32 # shape constraints q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] @@ -168,9 +203,17 @@ def token_attn_fwd_1(q, return @triton.jit - def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, - logics_head_dim_stride, logics_batch_stride, prob_head_dim_stride, prob_batch_stride, - BLOCK_SIZE: tl.constexpr): + def _token_attn_softmax_fwd( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + logics_head_dim_stride, + logics_batch_stride, + prob_head_dim_stride, + prob_batch_stride, + BLOCK_SIZE: tl.constexpr, + ): current_batch = tl.program_id(0) current_head = tl.program_id(1) @@ -178,20 +221,26 @@ def _token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - row = tl.load(softmax_logics + current_head * logics_head_dim_stride + - (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, - mask=col_offsets < current_batch_seq_len, - other=-float('inf')).to(tl.float32) + row = tl.load( + softmax_logics + + current_head * logics_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, + mask=col_offsets < current_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) row_minus_max = row - tl.max(row, axis=0) numerator = tl.exp(row_minus_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator - tl.store(softmax_prob_out + current_head * prob_head_dim_stride + - (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, - softmax_output, - mask=col_offsets < current_batch_seq_len) + tl.store( + softmax_prob_out + + current_head * prob_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, + softmax_output, + mask=col_offsets < current_batch_seq_len, + ) return @torch.no_grad() @@ -220,11 +269,27 @@ def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, return @triton.jit - def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, - kv_cache_loc_b_stride, kv_cache_loc_s_stride, prob_head_dim_stride, prob_batch_stride, - v_batch_stride, v_head_stride, v_head_dim_stride, attn_out_batch_stride, - attn_out_head_stride, attn_out_head_dim_stride, HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr): + def _token_attn_2_kernel( + Prob, + V, + attn_out, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + prob_head_dim_stride, + prob_batch_stride, + v_batch_stride, + v_head_stride, + v_head_dim_stride, + attn_out_batch_stride, + attn_out_head_stride, + attn_out_head_dim_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): current_batch = tl.program_id(0) current_head = tl.program_id(1) @@ -232,7 +297,6 @@ def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv offs_d = tl.arange(0, HEAD_DIM) current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = current_batch_seq_len current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride @@ -242,19 +306,29 @@ def _token_attn_2_kernel(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv acc = tl.zeros([HEAD_DIM], dtype=tl.float32) for start_n in range(0, current_batch_seq_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load(Prob + p_offs + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0) - v_loc = tl.load(kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0) - v_value = tl.load(V + v_offs + v_loc[:, None] * v_batch_stride, - mask=(start_n + offs_n[:, None]) < current_batch_seq_len, - other=0.0) + p_value = tl.load( + Prob + p_offs + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_loc = tl.load( + kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * v_batch_stride, + mask=(start_n + offs_n[:, None]) < current_batch_seq_len, + other=0.0, + ) acc += tl.sum(p_value[:, None] * v_value, 0) acc = acc.to(tl.float16) - off_o = current_batch * attn_out_batch_stride + current_head * attn_out_head_stride + offs_d * attn_out_head_dim_stride + off_o = ( + current_batch * attn_out_batch_stride + + current_head * attn_out_head_stride + + offs_d * attn_out_head_dim_stride + ) out_ptrs = attn_out + off_o tl.store(out_ptrs, acc) return @@ -296,15 +370,9 @@ def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cac return @torch.no_grad() - def token_attention_fwd(q, - k, - v, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=None): + def token_attention_fwd( + q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None + ): head_num = k.shape[1] batch_size = kv_cache_seq_len.shape[0] calcu_shape1 = (batch_size, head_num, k.shape[2]) @@ -312,21 +380,24 @@ def token_attention_fwd(q, att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - token_attn_fwd_1(q.view(calcu_shape1), - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=alibi) + token_attn_fwd_1( + q.view(calcu_shape1), + k, + att_m_tensor, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=alibi, + ) prob = torch.empty_like(att_m_tensor) token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) att_m_tensor = None - token_attn_fwd_2(prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, - max_len_in_batch) + token_attn_fwd_2( + prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch + ) prob = None diff --git a/colossalai/lazy/__init__.py b/colossalai/lazy/__init__.py index 4387107bf773..c6b813c50036 100644 --- a/colossalai/lazy/__init__.py +++ b/colossalai/lazy/__init__.py @@ -1,6 +1,6 @@ from .lazy_init import LazyInitContext, LazyTensor __all__ = [ - 'LazyInitContext', - 'LazyTensor', + "LazyInitContext", + "LazyTensor", ] diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index e071563c045a..ebaf2e1600fc 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from types import MethodType from typing import Callable, Dict, Optional, Union @@ -35,43 +34,43 @@ "eye", ] -_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split'] +_EARLY_MATERIALIZED_OPS = ["__getitem__", "split"] # If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset) # without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block. # These ops cannot be unwrapped using .data -_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__', '__set__', 'numel', 'size', 'dim'] +_CHANGE_META_OPS = ["_cudnn_rnn_flatten_weight", "requires_grad_", "__get__", "__set__", "numel", "size", "dim"] _LEGACY_TENSOR_CONSTRUCTOR = { - 'FloatTensor': torch.float, - 'DoubleTensor': torch.double, - 'HalfTensor': torch.half, - 'BFloat16Tensor': torch.bfloat16, - 'ByteTensor': torch.uint8, - 'CharTensor': torch.int8, - 'ShortTensor': torch.short, - 'IntTensor': torch.int, - 'LongTensor': torch.long, - 'BoolTensor': torch.bool, + "FloatTensor": torch.float, + "DoubleTensor": torch.double, + "HalfTensor": torch.half, + "BFloat16Tensor": torch.bfloat16, + "ByteTensor": torch.uint8, + "CharTensor": torch.int8, + "ShortTensor": torch.short, + "IntTensor": torch.int, + "LongTensor": torch.long, + "BoolTensor": torch.bool, } _EMPTY_DATA = torch.empty(0) class _MyTensor(Tensor): - """This class is only for correctness verification. - """ - _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None + """This class is only for correctness verification.""" + + _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None default_device: Optional[torch.device] = None - def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor': + def __new__(cls, func, *args, concrete_data=None, **kwargs) -> "_MyTensor": cls._pre_op_fn() if concrete_data is not None: # uniform api as LazyTensor data = concrete_data else: - kwargs['device'] = cls.default_device + kwargs["device"] = cls.default_device data = func(*args, **kwargs) return Tensor._make_subclass(cls, data, require_grad=data.requires_grad) @@ -82,12 +81,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def _data_tolist(tensor: torch.Tensor) -> list: - """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor. - """ + """tolist() method is not allowed for a subclass of tensor. Tensor.data returns a Tensor.""" return tensor.data.tolist() -def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: +def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor: """Convert a lazy tensor's class to target's class, with target's data. The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models. @@ -104,7 +102,7 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor: tensor.__class__ = cls_to_become if cls_to_become is Parameter: # to fit UninitializedParameter - delattr(tensor, '_is_param') + delattr(tensor, "_is_param") tensor.data = target tensor.requires_grad = target.requires_grad # subclass of torch.Tensor does not have tolist() method @@ -147,8 +145,8 @@ class LazyTensor(torch.Tensor): """ _repr = True - _meta_data: Optional[MetaTensor] = None # shape, dtype, device - _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None + _meta_data: Optional[MetaTensor] = None # shape, dtype, device + _pre_op_fn: Callable[["LazyTensor"], None] = lambda *args: None default_device: Optional[torch.device] = None @@ -159,8 +157,8 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): elem = concrete_data else: if meta_data is None: - device = kwargs.get('device', 'cpu') - elem = func(*args, **{**kwargs, 'device': 'meta'}) + device = kwargs.get("device", "cpu") + elem = func(*args, **{**kwargs, "device": "meta"}) meta_data = MetaTensor(elem, device=device) elem = meta_data._tensor # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here @@ -170,10 +168,10 @@ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs): def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs): if func.__name__ in _NORMAL_FACTORY: - kwargs = {**kwargs, 'device': LazyTensor.default_device} - self._factory_method = (func, args, kwargs) # (func, args, kwargs) - self._op_buffer = [] # (func, args, kwargs, replace) - self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data + kwargs = {**kwargs, "device": LazyTensor.default_device} + self._factory_method = (func, args, kwargs) # (func, args, kwargs) + self._op_buffer = [] # (func, args, kwargs, replace) + self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data def materialize(self) -> torch.Tensor: """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace). @@ -200,12 +198,11 @@ def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> to return _convert_cls(self, local_tensor) def clean(self) -> None: - """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized. - """ - delattr(self, '_factory_method') - delattr(self, '_op_buffer') - delattr(self, '_materialized_data') - delattr(self, '_meta_data') + """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.""" + delattr(self, "_factory_method") + delattr(self, "_op_buffer") + delattr(self, "_materialized_data") + delattr(self, "_meta_data") @staticmethod def _replace_with_materialized(x): @@ -221,8 +218,9 @@ def _materialize_data(self) -> torch.Tensor: # apply cached sequence self._pre_op_fn() - init_val = func(*tree_map(self._replace_with_materialized, args), - **tree_map(self._replace_with_materialized, kwargs)) + init_val = func( + *tree_map(self._replace_with_materialized, args), **tree_map(self._replace_with_materialized, kwargs) + ) self._materialized_data = self._rerun_ops(init_val) return self._materialized_data @@ -243,13 +241,13 @@ def replace(x): packed = None - for (func, args, kwargs) in self._op_buffer: + for func, args, kwargs in self._op_buffer: if func == torch.Tensor.requires_grad_: - packed = func, args, kwargs # requires grad should be set at last + packed = func, args, kwargs # requires grad should be set at last else: self._pre_op_fn() o = func(*tree_map(replace, args), **tree_map(replace, kwargs)) - target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value + target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value # super-dainiu: set requires_grad after all inplace-ops are done if packed is not None: @@ -268,8 +266,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): # These OPs cannot be lazy and related tensors should be early materialized tree_map(cls._replace_with_materialized, args) tree_map(cls._replace_with_materialized, kwargs) - is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__')) - or func.__name__ in ('__setitem__', '__set__')) + is_inplace: bool = ( + func.__name__.endswith("_") + and not (func.__name__.endswith("__")) + or func.__name__ in ("__setitem__", "__set__") + ) is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS @@ -285,11 +286,11 @@ def unwrap(x): target: LazyTensor = args[0].clone() target._op_buffer.append((func, args, kwargs)) - target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]), - **tree_map(unwrap, kwargs)) + target._meta_data = getattr(target._meta_data, func.name)( + *tree_map(unwrap, args[1:]), **tree_map(unwrap, kwargs) + ) return target else: - meta_to_lazy = {} def unwrap(x): @@ -328,10 +329,9 @@ def wrap(y, i=None): @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - pass # skip + pass # skip def clone(self) -> "LazyTensor": - def factory_fn(): # if self is materialized, return self new_tensor = self.materialize() if type(self) is LazyTensor else self @@ -346,8 +346,10 @@ def detach(self) -> Tensor: def __deepcopy__(self, memo): if not self.is_leaf: - raise RuntimeError("Only Tensors created explicitly by the user " - "(graph leaves) support the deepcopy protocol at the moment") + raise RuntimeError( + "Only Tensors created explicitly by the user " + "(graph leaves) support the deepcopy protocol at the moment" + ) if id(self) in memo: return memo[id(self)] @@ -375,7 +377,7 @@ def data(self): return self @data.setter - def data(self, other: 'LazyTensor'): + def data(self, other: "LazyTensor"): """This is sightly different from oringinal `data` setter. E.g.: @@ -413,7 +415,7 @@ def __hash__(self): def __rpow__(self, other): dtype = torch.result_type(self, other) - return torch.tensor(other, dtype=dtype, device=self.device)**self + return torch.tensor(other, dtype=dtype, device=self.device) ** self class LazyInitContext: @@ -444,11 +446,14 @@ class LazyInitContext: 1. Quantization strategies can be applied before allocating real memory. 2. Lazy initialization seems slower than normal initialization. """ + _replaced: bool = False - def __init__(self, - tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor, - default_device: Optional[Union[torch.device, str, int]] = None): + def __init__( + self, + tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor, + default_device: Optional[Union[torch.device, str, int]] = None, + ): assert tensor_cls is LazyTensor or tensor_cls is _MyTensor self.overrides = {} self.tensor_cls = tensor_cls @@ -457,7 +462,7 @@ def __init__(self, def __enter__(self): if LazyInitContext._replaced: - raise RuntimeError(f'LazyInitContext is not reentrant') + raise RuntimeError(f"LazyInitContext is not reentrant") LazyInitContext._replaced = True self.old_default_device = self.tensor_cls.default_device self.tensor_cls.default_device = self.default_device @@ -485,17 +490,17 @@ def wrapper(*args, **kwargs): return args[0] elif len(args) == 1: # (object data, *, torch.device device) - kwargs = {**kwargs, 'dtype': dtype} - replaced, orig = self.overrides['tensor'] + kwargs = {**kwargs, "dtype": dtype} + replaced, orig = self.overrides["tensor"] return replaced(*args, **kwargs) elif _is_int_tuple(args): # (tuple of ints size, *, torch.device device) - kwargs = {**kwargs, 'dtype': dtype} - replaced, orig = self.overrides['empty'] + kwargs = {**kwargs, "dtype": dtype} + replaced, orig = self.overrides["empty"] return replaced(*args, **kwargs) else: raise TypeError( - f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)' + f"new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)" ) return wrapper, target @@ -514,23 +519,29 @@ def wrapper(*args, **kwargs): if callable(getattr(torch, target, None)) } - self.overrides.update({ - target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like')) - for target in _NORMAL_FACTORY - if callable(getattr(torch, target + '_like', None)) - }) - - self.overrides.update({ - target: wrap_legacy_constructor(getattr(torch, target), dtype) - for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() - if callable(getattr(torch, target, None)) - }) - - self.overrides.update({ - target: wrap_no_meta_factory(getattr(torch, target)) - for target in _NO_META_FACTORY - if callable(getattr(torch, target, None)) - }) + self.overrides.update( + { + target + "_like": wrap_factory_like_method(getattr(torch, target), getattr(torch, target + "_like")) + for target in _NORMAL_FACTORY + if callable(getattr(torch, target + "_like", None)) + } + ) + + self.overrides.update( + { + target: wrap_legacy_constructor(getattr(torch, target), dtype) + for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items() + if callable(getattr(torch, target, None)) + } + ) + + self.overrides.update( + { + target: wrap_no_meta_factory(getattr(torch, target)) + for target in _NO_META_FACTORY + if callable(getattr(torch, target, None)) + } + ) for name, (wrapper, orig) in self.overrides.items(): setattr(torch, name, wrapper) @@ -556,10 +567,9 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) @staticmethod - def distribute(module: nn.Module, - device_mesh: DeviceMesh, - sharding_spec_dict: Dict[str, ShardingSpec], - verbose: bool = False) -> nn.Module: + def distribute( + module: nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: Dict[str, ShardingSpec], verbose: bool = False + ) -> nn.Module: """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: @@ -574,9 +584,9 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) -def _apply_to_lazy_module(module: nn.Module, - apply_fn: Callable[[str, torch.Tensor], None], - verbose: bool = False) -> nn.Module: +def _apply_to_lazy_module( + module: nn.Module, apply_fn: Callable[[str, torch.Tensor], None], verbose: bool = False +) -> nn.Module: if verbose: # verbose info param_cnt = 0 @@ -590,7 +600,7 @@ def _apply_to_lazy_module(module: nn.Module, if verbose: param_cnt += 1 total_numel += p.numel() - if getattr(p, '_materialized_data', False) is None: + if getattr(p, "_materialized_data", False) is None: # if no _materialized_data attr, the tensor is not lazy param_lazy_cnt += 1 else: @@ -612,10 +622,11 @@ def _apply_to_lazy_module(module: nn.Module, if verbose: non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0 - _print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}') - _print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}') + _print_rank_0(f"Param lazy rate: {param_lazy_cnt}/{param_cnt}") + _print_rank_0(f"Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}") _print_rank_0( - f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%') + f"Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%" + ) return module diff --git a/colossalai/legacy/__init__.py b/colossalai/legacy/__init__.py index f51941ee800b..4d6ad357a2fa 100644 --- a/colossalai/legacy/__init__.py +++ b/colossalai/legacy/__init__.py @@ -1,9 +1,9 @@ from .initialize import initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch __all__ = [ - 'launch', - 'launch_from_openmpi', - 'launch_from_slurm', - 'launch_from_torch', - 'initialize', + "launch", + "launch_from_openmpi", + "launch_from_slurm", + "launch_from_torch", + "initialize", ] diff --git a/colossalai/legacy/amp/__init__.py b/colossalai/legacy/amp/__init__.py index e83a7f6ac5cd..9d17d88b4c79 100644 --- a/colossalai/legacy/amp/__init__.py +++ b/colossalai/legacy/amp/__init__.py @@ -12,7 +12,7 @@ from .naive_amp import convert_to_naive_amp from .torch_amp import convert_to_torch_amp -__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE'] +__all__ = ["convert_to_amp", "convert_to_naive_amp", "convert_to_apex_amp", "convert_to_torch_amp", "AMP_TYPE"] def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): @@ -38,8 +38,7 @@ def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mod For ``torch_amp``, please check `torch_amp config `_. """ - assert isinstance(mode, AMP_TYPE), \ - f'expected the argument mode be AMP_TYPE, but got {type(mode)}' + assert isinstance(mode, AMP_TYPE), f"expected the argument mode be AMP_TYPE, but got {type(mode)}" if amp_config is None: amp_config = Config() diff --git a/colossalai/legacy/amp/amp_type.py b/colossalai/legacy/amp/amp_type.py index 6f322f866cfc..5ad5faf08b71 100644 --- a/colossalai/legacy/amp/amp_type.py +++ b/colossalai/legacy/amp/amp_type.py @@ -5,6 +5,6 @@ class AMP_TYPE(Enum): - APEX = 'apex' - TORCH = 'torch' - NAIVE = 'naive' + APEX = "apex" + TORCH = "torch" + NAIVE = "naive" diff --git a/colossalai/legacy/amp/apex_amp/__init__.py b/colossalai/legacy/amp/apex_amp/__init__.py index 51b9b97dccce..680c6e45ca9d 100644 --- a/colossalai/legacy/amp/apex_amp/__init__.py +++ b/colossalai/legacy/amp/apex_amp/__init__.py @@ -34,9 +34,10 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config): More details about ``amp_config`` refer to `amp_config `_. """ import apex.amp as apex_amp + model, optimizer = apex_amp.initialize(model, optimizer, **amp_config) optimizer = ApexAMPOptimizer(optimizer) return model, optimizer -__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer'] +__all__ = ["convert_to_apex_amp", "ApexAMPOptimizer"] diff --git a/colossalai/legacy/amp/apex_amp/apex_amp.py b/colossalai/legacy/amp/apex_amp/apex_amp.py index acc051181562..048c51891b17 100644 --- a/colossalai/legacy/amp/apex_amp/apex_amp.py +++ b/colossalai/legacy/amp/apex_amp/apex_amp.py @@ -15,7 +15,7 @@ class ApexAMPOptimizer(OptimizerWrapper): - """ A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm + """A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm methods """ diff --git a/colossalai/legacy/amp/naive_amp/__init__.py b/colossalai/legacy/amp/naive_amp/__init__.py index 2ee84fc763b1..36e402299147 100644 --- a/colossalai/legacy/amp/naive_amp/__init__.py +++ b/colossalai/legacy/amp/naive_amp/__init__.py @@ -41,7 +41,7 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): output_to_fp32 = is_no_pp_or_last_stage() model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) - use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True) + use_dynamic_grad_scaler = amp_config.pop("dynamic_grad_scale", True) if use_dynamic_grad_scaler: scaler_class = DynamicGradScaler else: @@ -57,4 +57,4 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): return model, optimizer -__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer'] +__all__ = ["convert_to_naive_amp", "NaiveAMPOptimizer", "FP16Optimizer"] diff --git a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py index 2733477599f7..97ec57fbd007 100644 --- a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py @@ -21,7 +21,7 @@ except: fused_optim = None -__all__ = ['FP16Optimizer'] +__all__ = ["FP16Optimizer"] def load_fused_optim(): @@ -63,13 +63,15 @@ class FP16Optimizer(Optimizer): verbose (bool, optional): if set to `True`, will print debug info. Default False. """ - def __init__(self, - optimizer: Optimizer, - grad_scaler: BaseGradScaler, - verbose: bool = False, - clip_grad_norm=0, - dp_process_group: ProcessGroup = None, - mp_process_group: ProcessGroup = None): + def __init__( + self, + optimizer: Optimizer, + grad_scaler: BaseGradScaler, + verbose: bool = False, + clip_grad_norm=0, + dp_process_group: ProcessGroup = None, + mp_process_group: ProcessGroup = None, + ): # have a defaults for compatibility with pytorch optim self._optimizer = optimizer self._defaults = optimizer.defaults @@ -117,10 +119,10 @@ def _get_process_group(parallel_mode): fp32_master_params = [] fp32_params = [] # For all the parameters in this group: - for i, param in enumerate(param_group['params']): + for i, param in enumerate(param_group["params"]): if param.requires_grad: # float16 params: - if param.type() in ['torch.cuda.HalfTensor']: + if param.type() in ["torch.cuda.HalfTensor"]: fp16_params.append(param) # Create a fp32 copy @@ -129,7 +131,7 @@ def _get_process_group(parallel_mode): copy_tensor_parallel_attributes(param, fp32_param) # Replace the optimizer params with the new fp32 copy. - param_group['params'][i] = fp32_param + param_group["params"][i] = fp32_param fp32_master_params.append(fp32_param) # Reset existing state dict key to the new main param. @@ -137,11 +139,13 @@ def _get_process_group(parallel_mode): self._optimizer.state[fp32_param] = self._optimizer.state.pop(param) # fp32 params. - elif param.type() == 'torch.cuda.FloatTensor': + elif param.type() == "torch.cuda.FloatTensor": fp32_params.append(param) else: - raise TypeError('Expected parameter of type torch.cuda.FloatTensor ' - f'or torch.cuda.HalfTensor, but got {param.type()}') + raise TypeError( + "Expected parameter of type torch.cuda.FloatTensor " + f"or torch.cuda.HalfTensor, but got {param.type()}" + ) self._fp16_param_groups.append(fp16_params) self._fp32_master_param_groups.append(fp32_master_params) @@ -160,12 +164,12 @@ def _get_process_group(parallel_mode): f"clip_grad_norm = {clip_grad_norm}\n" f"grad_scaler = {self._grad_scaler.__class__.__name__}" f"==========================================", - ranks=[0]) + ranks=[0], + ) @property def max_norm(self): - """Returns the maximum norm of gradient clipping. - """ + """Returns the maximum norm of gradient clipping.""" return self._clip_grad_max_norm @property @@ -211,7 +215,7 @@ def _check_overflow(self): # check for overflow for group in self._optimizer.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is not None and has_inf_or_nan(p.grad): self._found_overflow.fill_(1.0) break @@ -235,7 +239,7 @@ def zero_grad(self, set_to_none=True): # set_to_none = True can save some memory space for param_group in self._optimizer.param_groups: - zero_gard_by_list(param_group['params'], set_to_none=set_to_none) + zero_gard_by_list(param_group["params"], set_to_none=set_to_none) def _get_fp32_param_groups_to_update(self): return self._fp32_master_param_groups + self._fp32_param_groups @@ -262,13 +266,12 @@ def _update_fp16_param_from_fp32_param(self): for fp16_param, fp32_param in zip(fp16_group, fp32_group): fp16_param_data.append(fp16_param.data) fp32_master_param_data.append(fp32_param.data) - _multi_tensor_copy_this_to_that(this=fp32_master_param_data, - that=fp16_param_data, - overflow_buf=self._dummy_overflow_buf) + _multi_tensor_copy_this_to_that( + this=fp32_master_param_data, that=fp16_param_data, overflow_buf=self._dummy_overflow_buf + ) def step(self): - """Update the model parameters. - """ + """Update the model parameters.""" # Copy gradients from model params to main params. self._assign_grad_to_fp32_master_param() @@ -307,14 +310,13 @@ def backward(self, loss): scaled_loss.backward() def state_dict(self): - """Returns the states of the fp16 optimizer as a dict object. - """ + """Returns the states of the fp16 optimizer as a dict object.""" state_dict = {} - state_dict['optimizer'] = self._optimizer.state_dict() + state_dict["optimizer"] = self._optimizer.state_dict() if self.grad_scaler: - state_dict['grad_scaler'] = self.grad_scaler.state_dict() - state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups + state_dict["grad_scaler"] = self.grad_scaler.state_dict() + state_dict["fp32_master_param_groups"] = self._fp32_master_param_groups return state_dict def load_state_dict(self, state_dict): @@ -325,16 +327,17 @@ def load_state_dict(self, state_dict): """ # Optimizer. - self._optimizer.load_state_dict(state_dict['optimizer']) + self._optimizer.load_state_dict(state_dict["optimizer"]) # Grad scaler. - if 'grad_scaler' in state_dict: - self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + if "grad_scaler" in state_dict: + self.grad_scaler.load_state_dict(state_dict["grad_scaler"]) # Copy data for the main params. - if 'fp32_master_param_groups' in state_dict: - for current_group, ckpt_group in zip(self._fp32_master_param_groups, - state_dict['fp32_master_param_groups']): + if "fp32_master_param_groups" in state_dict: + for current_group, ckpt_group in zip( + self._fp32_master_param_groups, state_dict["fp32_master_param_groups"] + ): for current_param, ckpt_param in zip(current_group, ckpt_group): current_param.data.copy_(ckpt_param.data) @@ -346,7 +349,7 @@ def clip_grad_norm(self, clip_grad): """ params = [] for param_group in self._optimizer.param_groups: - for param in param_group['params']: + for param in param_group["params"]: params.append(param) return clip_grad_norm_fp32(params, clip_grad) diff --git a/colossalai/legacy/amp/naive_amp/_utils.py b/colossalai/legacy/amp/naive_amp/_utils.py index 7633705e19fb..aa5a91146bb0 100644 --- a/colossalai/legacy/amp/naive_amp/_utils.py +++ b/colossalai/legacy/amp/naive_amp/_utils.py @@ -27,7 +27,7 @@ def has_inf_or_nan(tensor): raise return True else: - if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: + if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum: return True return False diff --git a/colossalai/legacy/amp/naive_amp/naive_amp.py b/colossalai/legacy/amp/naive_amp/naive_amp.py index 1fab3e5a0d0d..f9c298941fa9 100644 --- a/colossalai/legacy/amp/naive_amp/naive_amp.py +++ b/colossalai/legacy/amp/naive_amp/naive_amp.py @@ -45,9 +45,11 @@ def step(self): def clip_grad_norm(self, model: nn.Module, max_norm: float): if self.optim.max_norm == max_norm: return - raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). " - "If you have supplied clip_grad_norm in the amp_config, " - "executing the method clip_grad_norm is not allowed.") + raise RuntimeError( + "NaiveAMP optimizer has clipped gradients during optimizer.step(). " + "If you have supplied clip_grad_norm in the amp_config, " + "executing the method clip_grad_norm is not allowed." + ) class NaiveAMPModel(nn.Module): @@ -66,11 +68,13 @@ class NaiveAMPModel(nn.Module): in `parallel_mode `_. """ - def __init__(self, - model: nn.Module, - output_to_fp32: bool = True, - parallel_mode: ParallelMode = ParallelMode.DATA, - sync_buffer: bool = True): + def __init__( + self, + model: nn.Module, + output_to_fp32: bool = True, + parallel_mode: ParallelMode = ParallelMode.DATA, + sync_buffer: bool = True, + ): super().__init__() self.model = model.half() self._output_to_fp32 = output_to_fp32 diff --git a/colossalai/legacy/amp/torch_amp/__init__.py b/colossalai/legacy/amp/torch_amp/__init__.py index 893cc890d68e..ad2416eef06a 100644 --- a/colossalai/legacy/amp/torch_amp/__init__.py +++ b/colossalai/legacy/amp/torch_amp/__init__.py @@ -9,10 +9,9 @@ from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer -def convert_to_torch_amp(model: nn.Module, - optimizer: Optimizer, - criterion: Optional[_Loss] = None, - amp_config: Optional[Config] = None): +def convert_to_torch_amp( + model: nn.Module, optimizer: Optimizer, criterion: Optional[_Loss] = None, amp_config: Optional[Config] = None +): """A helper function to wrap training components with Pytorch AMP modules Args: @@ -42,4 +41,4 @@ def convert_to_torch_amp(model: nn.Module, return model, optimizer, criterion -__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer'] +__all__ = ["convert_to_torch_amp", "TorchAMPModel", "TorchAMPLoss", "TorchAMPOptimizer"] diff --git a/colossalai/legacy/amp/torch_amp/_grad_scaler.py b/colossalai/legacy/amp/torch_amp/_grad_scaler.py index 543dac6ab5ef..fc1aeec234fd 100644 --- a/colossalai/legacy/amp/torch_amp/_grad_scaler.py +++ b/colossalai/legacy/amp/torch_amp/_grad_scaler.py @@ -23,7 +23,7 @@ class _MultiDeviceReplicator(object): """ def __init__(self, master_tensor: torch.Tensor) -> None: - assert master_tensor.is_cuda or master_tensor.device.type == 'xla' + assert master_tensor.is_cuda or master_tensor.device.type == "xla" self.master = master_tensor self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} @@ -118,7 +118,7 @@ class GradScaler(object): invokes the underlying ``optimizer.step()``, and other methods become no-ops. """ - def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True): + def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True): if enabled and not torch.cuda.is_available(): warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.") self._enabled = False @@ -174,7 +174,7 @@ def scale(self, outputs): # Short-circuit for the common case. if isinstance(outputs, torch.Tensor): - assert outputs.is_cuda or outputs.device.type == 'xla' + assert outputs.is_cuda or outputs.device.type == "xla" if self._scale is None: self._lazy_init_scale_growth_tracker(outputs.device) assert self._scale is not None @@ -186,7 +186,7 @@ def scale(self, outputs): def apply_scale(val): if isinstance(val, torch.Tensor): - assert val.is_cuda or val.device.type == 'xla' + assert val.is_cuda or val.device.type == "xla" if len(stash) == 0: if self._scale is None: self._lazy_init_scale_growth_tracker(val.device) @@ -214,7 +214,7 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict # Google says mypy struggles with defaultdicts type annotations. - per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] with torch.no_grad(): for group in optimizer.param_groups: for param in group["params"]: @@ -238,8 +238,9 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): for device, per_dtype_grads in per_device_and_dtype_grads.items(): for grads in per_dtype_grads.values(): - torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device), - per_device_inv_scale.get(device)) + torch._amp_foreach_non_finite_check_and_unscale_( + grads, per_device_found_inf.get(device), per_device_inv_scale.get(device) + ) # For tensor parallel parameters it should be all-reduced over tensor parallel process group if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: vals = [val for val in per_device_found_inf._per_device_tensors.values()] @@ -328,7 +329,7 @@ def step(self, optimizer, *args, **kwargs): .. warning:: Closure use is not currently supported. """ - if (not self._enabled): + if not self._enabled: return optimizer.step(*args, **kwargs) if "closure" in kwargs: @@ -343,7 +344,7 @@ def step(self, optimizer, *args, **kwargs): retval = None - if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + if hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling: # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. # The contract with custom optimizers is that their step() should accept an additional, # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: @@ -391,14 +392,14 @@ def update(self, new_scale=None): if new_scale is not None: # Accept a new user-defined scale. if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] + self._scale.fill_(new_scale) # type: ignore[union-attr] else: reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." # type: ignore[attr-defined] assert isinstance(new_scale, torch.cuda.FloatTensor), reason assert new_scale.numel() == 1, reason assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] + self._scale.copy_(new_scale) # type: ignore[union-attr] else: # Consume shared inf/nan data collected from optimizers to update the scale. # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. @@ -416,11 +417,23 @@ def update(self, new_scale=None): found_inf_combined += found_infs[i] if self._higher_than_torch18: - torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor, - self._backoff_factor, self._growth_interval) + torch._amp_update_scale_( + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) else: - self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor, - self._backoff_factor, self._growth_interval) + self._scale = torch._amp_update_scale( + _growth_tracker, + _scale, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) # To prepare for next iteration, clear the data collected from optimizers this iteration. self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) @@ -507,13 +520,17 @@ def state_dict(self): If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` should be called after :meth:`update`. """ - return { - "scale": self.get_scale(), - "growth_factor": self._growth_factor, - "backoff_factor": self._backoff_factor, - "growth_interval": self._growth_interval, - "_growth_tracker": self._get_growth_tracker() - } if self._enabled else {} + return ( + { + "scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker(), + } + if self._enabled + else {} + ) def load_state_dict(self, state_dict): r""" @@ -526,8 +543,10 @@ def load_state_dict(self, state_dict): return if len(state_dict) == 0: - raise RuntimeError("The source state dict is empty, possibly because it was saved " - "from a disabled instance of GradScaler.") + raise RuntimeError( + "The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler." + ) self._init_scale = state_dict["scale"] if self._scale is not None: @@ -542,15 +561,17 @@ def load_state_dict(self, state_dict): def __getstate__(self): state = self.__dict__.copy() if self._enabled: - assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ - "of an iteration, or at the end after scaler.update()." + assert len(self._per_optimizer_states) == 0, ( + "A GradScaler instance may only be pickled at the beginning " + "of an iteration, or at the end after scaler.update()." + ) # Pickling _scale and _growth_tracker Tensors directly triggers # "warnings.warn("pickle support for Storage will be removed in 1.5..." # so instead, we set the unpickled instance up to reinitialize them lazily. - state['_init_scale'] = self.get_scale() - state['_init_growth_tracker'] = self._get_growth_tracker() - state['_scale'] = None - state['_growth_tracker'] = None + state["_init_scale"] = self.get_scale() + state["_init_growth_tracker"] = self._get_growth_tracker() + state["_scale"] = None + state["_growth_tracker"] = None return state def __setstate__(self, state): @@ -562,8 +583,9 @@ def _check_inf_per_device(self, optimizer): dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device) found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device) - self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ - self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = self._unscale_grads_( + optimizer, dummy_inv_scale, found_inf, True + ) return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/colossalai/legacy/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py index c45a5956a205..ced5cc3e6647 100644 --- a/colossalai/legacy/amp/torch_amp/torch_amp.py +++ b/colossalai/legacy/amp/torch_amp/torch_amp.py @@ -42,8 +42,7 @@ def backward(self, loss: Tensor): self.scaler.scale(loss).backward() def step(self): - """Update the parameters of the model - """ + """Update the parameters of the model""" self.scaler.step(self.optim) self.scaler.update() diff --git a/colossalai/legacy/builder/__init__.py b/colossalai/legacy/builder/__init__.py index cf09e1e7a31a..9af3d139d3b1 100644 --- a/colossalai/legacy/builder/__init__.py +++ b/colossalai/legacy/builder/__init__.py @@ -1,3 +1,3 @@ from .builder import build_from_config, build_from_registry, build_gradient_handler -__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry'] +__all__ = ["build_gradient_handler", "build_from_config", "build_from_registry"] diff --git a/colossalai/legacy/builder/builder.py b/colossalai/legacy/builder/builder.py index ff14f46dc61f..dec3bc1c2487 100644 --- a/colossalai/legacy/builder/builder.py +++ b/colossalai/legacy/builder/builder.py @@ -19,7 +19,7 @@ def build_from_config(module, config: dict): AssertionError: Raises an AssertionError if `module` is not a class """ - assert inspect.isclass(module), 'module must be a class' + assert inspect.isclass(module), "module must be a class" return module(**config) @@ -45,15 +45,15 @@ def build_from_registry(config, registry: Registry): Raises: Exception: Raises an Exception if an error occurred when building from registry. """ - config_ = config.copy() # keep the original config untouched - assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}' + config_ = config.copy() # keep the original config untouched + assert isinstance(registry, Registry), f"Expected type Registry but got {type(registry)}" - mod_type = config_.pop('type') - assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}' + mod_type = config_.pop("type") + assert registry.has(mod_type), f"{mod_type} is not found in registry {registry.name}" try: obj = registry.get_module(mod_type)(**config_) except Exception as e: - print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True) + print(f"An error occurred when building {mod_type} from registry {registry.name}", flush=True) raise e return obj @@ -74,6 +74,6 @@ def build_gradient_handler(config, model, optimizer): An object of :class:`colossalai.legacy.engine.BaseGradientHandler` """ config_ = config.copy() - config_['model'] = model - config_['optimizer'] = optimizer + config_["model"] = model + config_["optimizer"] = optimizer return build_from_registry(config_, GRADIENT_HANDLER) diff --git a/colossalai/legacy/communication/__init__.py b/colossalai/legacy/communication/__init__.py index 88ad0487b785..f4492b074425 100644 --- a/colossalai/legacy/communication/__init__.py +++ b/colossalai/legacy/communication/__init__.py @@ -14,21 +14,21 @@ from .utils import recv_obj_meta, send_obj_meta __all__ = [ - 'all_gather', - 'reduce_scatter', - 'all_reduce', - 'broadcast', - 'reduce', - 'send_forward', - 'send_forward_recv_forward', - 'send_forward_backward_recv_forward_backward', - 'send_backward', - 'send_backward_recv_backward', - 'send_backward_recv_forward', - 'send_forward_recv_backward', - 'recv_backward', - 'recv_forward', - 'ring_forward', - 'send_obj_meta', - 'recv_obj_meta', + "all_gather", + "reduce_scatter", + "all_reduce", + "broadcast", + "reduce", + "send_forward", + "send_forward_recv_forward", + "send_forward_backward_recv_forward_backward", + "send_backward", + "send_backward_recv_backward", + "send_backward_recv_forward", + "send_forward_recv_backward", + "recv_backward", + "recv_forward", + "ring_forward", + "send_obj_meta", + "recv_obj_meta", ] diff --git a/colossalai/legacy/communication/collective.py b/colossalai/legacy/communication/collective.py index 7471188226f0..9cf30f733dee 100644 --- a/colossalai/legacy/communication/collective.py +++ b/colossalai/legacy/communication/collective.py @@ -9,10 +9,10 @@ from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc -_all_gather_func = dist._all_gather_base \ - if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor -_reduce_scatter_func = dist._reduce_scatter_base \ - if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor +_all_gather_func = dist._all_gather_base if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor +_reduce_scatter_func = ( + dist._reduce_scatter_base if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor +) def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: @@ -50,11 +50,9 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: return out -def reduce_scatter(tensor: Tensor, - dim: int, - parallel_mode: ParallelMode, - op: ReduceOp = ReduceOp.SUM, - async_op: bool = False) -> Tensor: +def reduce_scatter( + tensor: Tensor, dim: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False +) -> Tensor: r"""Reduces all tensors then scatters it in a specific dimension to all members in the parallel group. @@ -93,10 +91,9 @@ def reduce_scatter(tensor: Tensor, return out -def all_reduce(tensor: Tensor, - parallel_mode: ParallelMode, - op: ReduceOp = ReduceOp.SUM, - async_op: bool = False) -> Tensor: +def all_reduce( + tensor: Tensor, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False +) -> Tensor: r"""Reduces the tensor data across whole parallel group in such a way that all get the final result. Note: @@ -201,16 +198,17 @@ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, s if dist.distributed_c10d._rank_not_in_group(group): return - if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1): + if not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1: raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.") # set tensor device to cuda if backend is nccl - device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu") + device = torch.cuda.current_device() if dist.get_backend(group) == "nccl" else torch.device("cpu") - my_rank = dist.get_rank() # use global rank + my_rank = dist.get_rank() # use global rank if my_rank == src: tensor_list, tensor_sizes = zip( - *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]) + *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list] + ) tensor_list = list(map(lambda x: x.to(device), tensor_list)) tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes)) diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py index e3f9108ab840..19c3919b6e29 100644 --- a/colossalai/legacy/communication/p2p.py +++ b/colossalai/legacy/communication/p2p.py @@ -82,16 +82,18 @@ def filling_ops_queue(obj, comm_op, comm_rank, ops_queue): ops_queue.append(op_to_add) -def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, - object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, - recv_prev: bool = False, - recv_next: bool = False, - recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, - recv_next_shape: Union[torch.Size, List[torch.Size]] = None, - prev_rank: int = None, - next_rank: int = None, - dtype: torch.dtype = None, - scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: +def _communicate( + object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, + object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, + recv_prev: bool = False, + recv_next: bool = False, + recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, + recv_next_shape: Union[torch.Size, List[torch.Size]] = None, + prev_rank: int = None, + next_rank: int = None, + dtype: torch.dtype = None, + scatter_gather_tensors: bool = False, +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: """ Adapted from megatron.p2p_communication. Communicate tensors between stages. Used as helper method in other @@ -123,13 +125,15 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non if recv_prev: assert recv_prev_shape is not None - tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype, - scatter_gather_tensors) + tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes( + recv_prev_shape, dtype, scatter_gather_tensors + ) if recv_next: assert recv_next_shape is not None - tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype, - scatter_gather_tensors) + tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes( + recv_next_shape, dtype, scatter_gather_tensors + ) if object_send_prev is not None or recv_prev: if prev_rank is None: @@ -170,24 +174,25 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_() else: for index in range(len(tensor_recv_prev)): - tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view( - recv_prev_shape[index]).requires_grad_() + tensor_recv_prev[index] = ( + gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_() + ) if recv_next and recv_next_split: if isinstance(tensor_recv_next, torch.Tensor): tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_() else: for index in range(len(tensor_recv_next)): - tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view( - recv_next_shape[index]).requires_grad_() + tensor_recv_next[index] = ( + gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_() + ) return tensor_recv_prev, tensor_recv_next -def recv_forward(input_tensor_shape, - prev_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def recv_forward( + input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False +) -> Union[torch.Tensor, List[torch.Tensor]]: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. Args: @@ -200,18 +205,19 @@ def recv_forward(input_tensor_shape, if gpc.is_pipeline_first_stage(): input_tensor = None else: - input_tensor, _ = _communicate(recv_prev=True, - recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + input_tensor, _ = _communicate( + recv_prev=True, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return input_tensor -def recv_backward(output_grad_shape, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def recv_backward( + output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False +) -> Union[torch.Tensor, List[torch.Tensor]]: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. Args: @@ -224,11 +230,13 @@ def recv_backward(output_grad_shape, if gpc.is_pipeline_last_stage(): output_tensor_grad = None else: - _, output_tensor_grad = _communicate(recv_next=True, - recv_next_shape=output_grad_shape, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + _, output_tensor_grad = _communicate( + recv_next=True, + recv_next_shape=output_grad_shape, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return output_tensor_grad @@ -251,17 +259,14 @@ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=Fals prev_rank (int, optional): The rank of the recipient of the tensor """ if not gpc.is_pipeline_first_stage(): - _communicate(object_send_prev=input_tensor_grad, - prev_rank=prev_rank, - scatter_gather_tensors=scatter_gather_tensors) + _communicate( + object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors + ) -def send_forward_recv_backward(output_tensor, - output_grad_shape, - recv_next=True, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def send_forward_recv_backward( + output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False +) -> Union[torch.Tensor, List[torch.Tensor]]: """Batched communication operation. Sends the input tensor to the next stage in pipeline, while receives the gradient tensor from the next stage in pipeline as the input gradient tensor of this stage. @@ -276,21 +281,25 @@ def send_forward_recv_backward(output_tensor, if gpc.is_pipeline_last_stage(): output_tensor_grad = None else: - _, output_tensor_grad = _communicate(object_send_next=output_tensor, - recv_next=recv_next, - recv_next_shape=output_grad_shape, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + _, output_tensor_grad = _communicate( + object_send_next=output_tensor, + recv_next=recv_next, + recv_next_shape=output_grad_shape, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return output_tensor_grad -def send_backward_recv_forward(input_tensor_grad, - input_tensor_shape, - recv_prev=True, - prev_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def send_backward_recv_forward( + input_tensor_grad, + input_tensor_shape, + recv_prev=True, + prev_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Union[torch.Tensor, List[torch.Tensor]]: """Batched communication operation. Sends the gradient tensor to the previous stage in pipeline, while receives the output tensor from the previous stage in pipeline as the input of this stage. @@ -305,22 +314,26 @@ def send_backward_recv_forward(input_tensor_grad, if gpc.is_pipeline_first_stage(): input_tensor = None else: - input_tensor, _ = _communicate(object_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + input_tensor, _ = _communicate( + object_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return input_tensor -def send_forward_recv_forward(output_tensor, - input_tensor_shape, - recv_prev=True, - prev_rank=None, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def send_forward_recv_forward( + output_tensor, + input_tensor_shape, + recv_prev=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Union[torch.Tensor, List[torch.Tensor]]: """Batched communication operation. Sends the input tensor to the next stage in pipeline, while receives the output tensor from the previous stage in pipeline as the input of this stage. @@ -332,23 +345,27 @@ def send_forward_recv_forward(output_tensor, Returns: Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. """ - input_tensor, _ = _communicate(object_send_next=output_tensor, - recv_prev=recv_prev, - recv_prev_shape=input_tensor_shape, - prev_rank=prev_rank, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + input_tensor, _ = _communicate( + object_send_next=output_tensor, + recv_prev=recv_prev, + recv_prev_shape=input_tensor_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return input_tensor -def send_backward_recv_backward(input_tensor_grad, - output_grad_shape, - recv_next=True, - prev_rank=None, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: +def send_backward_recv_backward( + input_tensor_grad, + output_grad_shape, + recv_next=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Union[torch.Tensor, List[torch.Tensor]]: """Batched communication operation. Sends the gradient tensor to the previous stage in pipeline, while receives the gradient tensor from the next member in pipeline as the input of this stage. @@ -360,27 +377,30 @@ def send_backward_recv_backward(input_tensor_grad, Returns: Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. """ - _, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad, - recv_next=recv_next, - recv_next_shape=output_grad_shape, - prev_rank=prev_rank, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + _, output_tensor_grad = _communicate( + object_send_prev=input_tensor_grad, + recv_next=recv_next, + recv_next_shape=output_grad_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return output_tensor_grad def send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - input_tensor_shape, - output_grad_shape, - recv_prev=True, - recv_next=True, - prev_rank=None, - next_rank=None, - dtype=torch.float, - scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: + output_tensor, + input_tensor_grad, + input_tensor_shape, + output_grad_shape, + recv_prev=True, + recv_next=True, + prev_rank=None, + next_rank=None, + dtype=torch.float, + scatter_gather_tensors=False, +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: """Batched communication operation. Sends the input tensor to the next stage in pipeline and the gradient tensor to the previous stage, while receives the input gradient tensor from the next stage and the input tensor from the previous stage. @@ -394,14 +414,16 @@ def send_forward_backward_recv_forward_backward( Returns: Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor) """ - input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor, - object_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - recv_prev_shape=input_tensor_shape, - recv_next_shape=output_grad_shape, - prev_rank=prev_rank, - next_rank=next_rank, - dtype=dtype, - scatter_gather_tensors=scatter_gather_tensors) + input_tensor, output_tensor_grad = _communicate( + object_send_next=output_tensor, + object_send_prev=input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + recv_prev_shape=input_tensor_shape, + recv_next_shape=output_grad_shape, + prev_rank=prev_rank, + next_rank=next_rank, + dtype=dtype, + scatter_gather_tensors=scatter_gather_tensors, + ) return input_tensor, output_tensor_grad diff --git a/colossalai/legacy/communication/p2p_v2.py b/colossalai/legacy/communication/p2p_v2.py index 66af214950f2..7c8d8bede069 100644 --- a/colossalai/legacy/communication/p2p_v2.py +++ b/colossalai/legacy/communication/p2p_v2.py @@ -62,10 +62,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - Any: object after unpickled """ buf = tensor.numpy().tobytes()[:tensor_size] - if b'cuda' in buf: + if b"cuda" in buf: buf_array = bytearray(buf) device_index = torch.cuda.current_device() - buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index + buf_array[buf_array.find(b"cuda") + 5] = 48 + device_index buf = bytes(buf_array) io_bytes = io.BytesIO(buf) @@ -123,8 +123,8 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No if local_rank == src: object_tensor = torch.cat(tensor_list) else: - object_tensor = torch.empty( # type: ignore[call-overload] - torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, ) @@ -138,7 +138,7 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No if local_rank != src: for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset:offset + obj_size] + obj_view = object_tensor[offset : offset + obj_size] obj_view = obj_view.type(torch.uint8) if obj_view.device != torch.device("cpu"): obj_view = obj_view.cpu() @@ -147,8 +147,10 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) # unconsistence in device - if isinstance(unpickle_object, - torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != torch.cuda.current_device() + ): unpickle_object = unpickle_object.cuda() object_list[i] = unpickle_object diff --git a/colossalai/legacy/communication/ring.py b/colossalai/legacy/communication/ring.py index e80192fb578d..a61dae56cd42 100644 --- a/colossalai/legacy/communication/ring.py +++ b/colossalai/legacy/communication/ring.py @@ -28,19 +28,20 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> ops = [] current_rank = gpc.get_global_rank() - tensor_recv_prev = torch.empty(buffer_shape, - requires_grad=True, - device=get_current_device(), - dtype=tensor_send_next.dtype) + tensor_recv_prev = torch.empty( + buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype + ) # send to next rank - send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next, - gpc.get_next_global_rank(parallel_mode)) + send_next_op = torch.distributed.P2POp( + torch.distributed.isend, tensor_send_next, gpc.get_next_global_rank(parallel_mode) + ) ops.append(send_next_op) # receive from prev rank - recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev, - gpc.get_prev_global_rank(parallel_mode)) + recv_prev_op = torch.distributed.P2POp( + torch.distributed.irecv, tensor_recv_prev, gpc.get_prev_global_rank(parallel_mode) + ) ops.append(recv_prev_op) if current_rank % 2 == 0: diff --git a/colossalai/legacy/communication/utils.py b/colossalai/legacy/communication/utils.py index 7e3dcf1e9820..6d77f3753fe8 100644 --- a/colossalai/legacy/communication/utils.py +++ b/colossalai/legacy/communication/utils.py @@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: if next_rank is None: next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} if isinstance(obj, torch.Tensor): send_obj_nums = torch.tensor(1, **tensor_kwargs) dist.send(send_obj_nums, next_rank) @@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: if prev_rank is None: prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} + tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} recv_obj_nums = torch.empty((), **tensor_kwargs) dist.recv(recv_obj_nums, prev_rank) if recv_obj_nums.item() == 1: @@ -122,6 +122,6 @@ def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: numel = torch.numel(tensor) numel_gathered = world_size * numel gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) - chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)] + chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)] dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D)) return gathered diff --git a/colossalai/legacy/constants.py b/colossalai/legacy/constants.py index 6cf9085f9fbb..5d64b676e73d 100644 --- a/colossalai/legacy/constants.py +++ b/colossalai/legacy/constants.py @@ -1,32 +1,32 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence'] -TENSOR_PARALLEL_MODE = 'tensor_parallel_mode' +ALLOWED_MODES = [None, "1d", "2d", "2.5d", "3d", "sequence"] +TENSOR_PARALLEL_MODE = "tensor_parallel_mode" # initializer INITIALIZER_MAPPING = { - 'data': 'Initializer_Data', - 'tensor': 'Initializer_Tensor', - 'pipeline': 'Initializer_Pipeline', - 'embedding': 'Initializer_Embedding', - '1d': 'Initializer_1D', - '2d': 'Initializer_2D', - '2.5d': 'Initializer_2p5D', - '3d': 'Initializer_3D', - 'sequence': 'Initializer_Sequence', - 'model': 'Initializer_Model', - 'moe': 'Initializer_Moe' + "data": "Initializer_Data", + "tensor": "Initializer_Tensor", + "pipeline": "Initializer_Pipeline", + "embedding": "Initializer_Embedding", + "1d": "Initializer_1D", + "2d": "Initializer_2D", + "2.5d": "Initializer_2p5D", + "3d": "Initializer_3D", + "sequence": "Initializer_Sequence", + "model": "Initializer_Model", + "moe": "Initializer_Moe", } # 3D parallelism groups -INPUT_GROUP_3D = 'input_group_3d' -WEIGHT_GROUP_3D = 'weight_group_3d' -OUTPUT_GROUP_3D = 'output_group_3d' -INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d' -OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d' +INPUT_GROUP_3D = "input_group_3d" +WEIGHT_GROUP_3D = "weight_group_3d" +OUTPUT_GROUP_3D = "output_group_3d" +INPUT_X_WEIGHT_3D = "input_x_weight_group_3d" +OUTPUT_X_WEIGHT_3D = "output_x_weight_group_3d" # Attributes of tensor parallel parameters -IS_TENSOR_PARALLEL = 'is_tensor_parallel' -NUM_PARTITIONS = 'num_partitions' +IS_TENSOR_PARALLEL = "is_tensor_parallel" +NUM_PARTITIONS = "num_partitions" TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS] diff --git a/colossalai/legacy/context/parallel_context.py b/colossalai/legacy/context/parallel_context.py index 8fdc3d6fea68..48bf8ab279e8 100644 --- a/colossalai/legacy/context/parallel_context.py +++ b/colossalai/legacy/context/parallel_context.py @@ -4,7 +4,6 @@ import random import socket from collections import Counter -from threading import local from typing import Union import numpy as np @@ -95,8 +94,9 @@ def detect_num_processes_on_current_node(self): @staticmethod def _check_parallel_mode(parallel_mode: ParallelMode): - assert isinstance(parallel_mode, ParallelMode), \ - f'expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}' + assert isinstance( + parallel_mode, ParallelMode + ), f"expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}" def get_global_rank(self): """Returns the global rank of the current device. @@ -239,8 +239,10 @@ def is_pipeline_first_stage(self, ignore_virtual=False): def is_pipeline_last_stage(self, ignore_virtual=False): if not ignore_virtual: - if self.virtual_pipeline_parallel_size \ - is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1: + if ( + self.virtual_pipeline_parallel_size is not None + and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1 + ): return False return self.is_last_rank(ParallelMode.PIPELINE) @@ -371,12 +373,12 @@ def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port (str): the master port for distributed training """ # initialize the default process group - init_method = f'tcp://[{host}]:{port}' + init_method = f"tcp://[{host}]:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # None will give the default global process group for pytorch dist operations ranks = list(range(world_size)) - cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None + cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL) self.add_global_rank(ParallelMode.GLOBAL, rank) @@ -398,10 +400,11 @@ def check_sanity(self): pps = self.pipeline_parallel_size tps = self.tensor_parallel_size ws = self.world_size - assert ws == dps * pps * \ - tps, f"Expected the world size {ws} to be equal to data" \ - f" parallel size ({dps}) * pipeline parallel size " \ - f"({pps}) * tensor parallel size ({tps})" + assert ws == dps * pps * tps, ( + f"Expected the world size {ws} to be equal to data" + f" parallel size ({dps}) * pipeline parallel size " + f"({pps}) * tensor parallel size ({tps})" + ) def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str): if key in config: @@ -409,10 +412,11 @@ def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str) if isinstance(ele, int): setattr(self, attr_name, ele) elif isinstance(ele, dict): - setattr(self, attr_name, ele['size']) + setattr(self, attr_name, ele["size"]) else: raise NotImplementedError( - f'{"Parallel configuration does not support this kind of argument, please use int or dict"}') + f'{"Parallel configuration does not support this kind of argument, please use int or dict"}' + ) def init_parallel_groups(self): """Initializes the parallel groups. @@ -427,10 +431,10 @@ def init_parallel_groups(self): self.world_size = world_size # set parallel size as attributes for global context - parallel_config = self.config.get('parallel', None) + parallel_config = self.config.get("parallel", None) if parallel_config is not None: - self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size') - self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size') + self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size") + self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size") # the user should not set the data parallel size manually # instead, it should be calculated based on other parallel config @@ -438,33 +442,33 @@ def init_parallel_groups(self): # get the tensor parallel mode and check tensor_parallel_mode = None - if parallel_config is not None and 'tensor' in \ - parallel_config and 'mode' in parallel_config['tensor']: - tensor_parallel_mode = parallel_config['tensor']['mode'] - assert tensor_parallel_mode in ALLOWED_MODES, \ - f"mode in the parallel config must be set to one of {ALLOWED_MODES}" + if parallel_config is not None and "tensor" in parallel_config and "mode" in parallel_config["tensor"]: + tensor_parallel_mode = parallel_config["tensor"]["mode"] + assert ( + tensor_parallel_mode in ALLOWED_MODES + ), f"mode in the parallel config must be set to one of {ALLOWED_MODES}" env.mode = tensor_parallel_mode self.check_sanity() pg_init = [] # LSG: init data parallel process group for compatibility with other parallel module such as zero - pg_init.append(dict(type=INITIALIZER_MAPPING['data'])) + pg_init.append(dict(type=INITIALIZER_MAPPING["data"])) # LSG: init model parallel process group for compatibility with amp and clip grad - pg_init.append(dict(type=INITIALIZER_MAPPING['model'])) + pg_init.append(dict(type=INITIALIZER_MAPPING["model"])) if self.pipeline_parallel_size > 1: - pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline'])) - pg_init.append(dict(type=INITIALIZER_MAPPING['tensor'])) + pg_init.append(dict(type=INITIALIZER_MAPPING["pipeline"])) + pg_init.append(dict(type=INITIALIZER_MAPPING["tensor"])) # init specific tensor parallel group if tensor_parallel_mode is not None: - tensor_parallel_cfg = parallel_config['tensor'].copy() + tensor_parallel_cfg = parallel_config["tensor"].copy() # remove duplicate parameters - tensor_parallel_cfg.pop('mode') - tensor_parallel_cfg.pop('size') + tensor_parallel_cfg.pop("mode") + tensor_parallel_cfg.pop("size") # add this config to initialize later pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg)) @@ -472,11 +476,16 @@ def init_parallel_groups(self): # run initialization of different process groups for initializer_cfg in pg_init: cfg = initializer_cfg.copy() - initializer_type = cfg.pop('type') - initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(rank, world_size, self.config, - self.data_parallel_size, - self.pipeline_parallel_size, - self.tensor_parallel_size, **cfg) + initializer_type = cfg.pop("type") + initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)( + rank, + world_size, + self.config, + self.data_parallel_size, + self.pipeline_parallel_size, + self.tensor_parallel_size, + **cfg, + ) parallel_setting = initializer.init_dist_group() if isinstance(parallel_setting, list): for args in parallel_setting: @@ -497,8 +506,7 @@ def is_initialized(self, parallel_mode: ParallelMode): return parallel_mode in self._groups def destroy(self): - """Destroys the current distributed parallel environment. - """ + """Destroys the current distributed parallel environment.""" for mode, group in self._groups.items(): if mode is not ParallelMode.GLOBAL: dist.destroy_process_group(group) @@ -519,7 +527,7 @@ def set_device(self, device_ordinal: int = None): torch.cuda.set_device(device_ordinal) if self._verbose: - self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}') + self._logger.info(f"process rank {global_rank} is bound to device {device_ordinal}") def set_seed(self, seed: int): """Sets seeds for all random libraries. @@ -552,21 +560,25 @@ def set_seed(self, seed: int): set_mode(ParallelMode.DATA) seeds = get_seeds() - seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()]) + seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()]) if self._verbose: - self._logger.info(f"initialized seed on rank {global_rank}, " - f"numpy: {seed}, python random: {seed}, {seed_str}," - f"the default parallel seed is {ParallelMode.DATA}.") + self._logger.info( + f"initialized seed on rank {global_rank}, " + f"numpy: {seed}, python random: {seed}, {seed_str}," + f"the default parallel seed is {ParallelMode.DATA}." + ) else: if self._verbose: self._logger.info( f"initialized seed on rank {global_rank}, " f"numpy: {seed}, python random: {seed}, pytorch: {seed}", - ranks=[0]) + ranks=[0], + ) self._logger.info( - 'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states', - ranks=[0]) + "WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states", + ranks=[0], + ) def set_virtual_pipeline_parallel_size(self, size): self.virtual_pipeline_parallel_size = size diff --git a/colossalai/legacy/context/parallel_mode.py b/colossalai/legacy/context/parallel_mode.py index 1cf6fa53dc1e..ceb52ff20da7 100644 --- a/colossalai/legacy/context/parallel_mode.py +++ b/colossalai/legacy/context/parallel_mode.py @@ -6,44 +6,43 @@ # parallel modes class ParallelMode(Enum): - """This is an enumeration class containing all possible parallel modes. - """ + """This is an enumeration class containing all possible parallel modes.""" - GLOBAL = 'global' + GLOBAL = "global" # common parallel - DATA = 'data' + DATA = "data" # model parallel - containing tensor and pipeline parallel groups # this is added to facilitate amp and grad clipping in hybrid parallel - MODEL = 'model' + MODEL = "model" # pipeline parallel - PIPELINE = 'pipe' + PIPELINE = "pipe" # containing all ranks in tensor parallel - TENSOR = 'tensor' + TENSOR = "tensor" # sequence parallel - SEQUENCE = 'sequence' - SEQUENCE_DP = 'sequence_dp' + SEQUENCE = "sequence" + SEQUENCE_DP = "sequence_dp" # 1D Parallel - PARALLEL_1D = '1d' + PARALLEL_1D = "1d" # 2D parallel - PARALLEL_2D_ROW = '2d_row' - PARALLEL_2D_COL = '2d_col' + PARALLEL_2D_ROW = "2d_row" + PARALLEL_2D_COL = "2d_col" # 3D parallel - PARALLEL_3D_INPUT = '3d_input' - PARALLEL_3D_WEIGHT = '3d_weight' - PARALLEL_3D_OUTPUT = '3d_output' + PARALLEL_3D_INPUT = "3d_input" + PARALLEL_3D_WEIGHT = "3d_weight" + PARALLEL_3D_OUTPUT = "3d_output" PARALLEL_3D_INPUT_X_WEIGHT = "3d_input_x_weight" PARALLEL_3D_OUTPUT_X_WEIGHT = "3d_output_x_weight" # 2.5D parallel - PARALLEL_2P5D_ROW = '2p5d_row' - PARALLEL_2P5D_COL = '2p5d_col' - PARALLEL_2P5D_DEP = '2p5d_dep' - PARALLEL_2P5D_XZ = '2p5d_xz' + PARALLEL_2P5D_ROW = "2p5d_row" + PARALLEL_2P5D_COL = "2p5d_col" + PARALLEL_2P5D_DEP = "2p5d_dep" + PARALLEL_2P5D_XZ = "2p5d_xz" diff --git a/colossalai/legacy/context/process_group_initializer/__init__.py b/colossalai/legacy/context/process_group_initializer/__init__.py index 48d52d7b9e52..a83165e40a8f 100644 --- a/colossalai/legacy/context/process_group_initializer/__init__.py +++ b/colossalai/legacy/context/process_group_initializer/__init__.py @@ -10,6 +10,14 @@ from .process_group_initializer import ProcessGroupInitializer __all__ = [ - 'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline', 'Initializer_Data', 'Initializer_2p5D', - 'Initializer_2D', 'Initializer_3D', 'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model' + "Initializer_Tensor", + "Initializer_Sequence", + "Initializer_Pipeline", + "Initializer_Data", + "Initializer_2p5D", + "Initializer_2D", + "Initializer_3D", + "Initializer_1D", + "ProcessGroupInitializer", + "Initializer_Model", ] diff --git a/colossalai/legacy/context/process_group_initializer/initializer_1d.py b/colossalai/legacy/context/process_group_initializer/initializer_1d.py index d853c6f06fc0..110a42cf880e 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_1d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_1d.py @@ -45,7 +45,7 @@ def init_dist_group(self): for i in range(self.num_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/legacy/context/process_group_initializer/initializer_2d.py b/colossalai/legacy/context/process_group_initializer/initializer_2d.py index 39f6a46890b6..1c08d4d4296a 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_2d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_2d.py @@ -14,9 +14,10 @@ def _check_summa_env_var(summa_dim): env_summa_dim = env.summa_dim if env_summa_dim: - assert int(env_summa_dim) == summa_dim, \ - 'SUMMA_DIM has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' + assert int(env_summa_dim) == summa_dim, ( + "SUMMA_DIM has been set in the current environment and " + "does not match with the value passed to this initialized" + ) else: env.summa_dim = summa_dim @@ -57,7 +58,7 @@ def init_dist_group(self): for j in range(self.summa_dim): ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -106,7 +107,7 @@ def init_dist_group(self): for j in range(self.summa_dim): ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -137,8 +138,9 @@ def __init__(self, *args, **kwargs): self.num_group = self.world_size // self.tensor_parallel_size self.summa_dim = int(math.sqrt(self.tensor_parallel_size)) - assert self.tensor_parallel_size == self.summa_dim ** 2, \ - "2D summa dim should equal to tensor parallel size ^ 0.5" + assert ( + self.tensor_parallel_size == self.summa_dim**2 + ), "2D summa dim should equal to tensor parallel size ^ 0.5" _check_summa_env_var(self.summa_dim) self.col_initializer = Initializer_2D_Col(self.num_group, self.summa_dim, *args, **kwargs) diff --git a/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py b/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py index bb7a3509572f..b7d71b96334d 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_2p5d.py @@ -19,12 +19,14 @@ def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int): env_tesseract_dep = env.tesseract_dep if env_tesseract_dim and env_tesseract_dep: - assert int(env_tesseract_dim) == tesseract_dim, \ - 'TESSERACT_DIM has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' - assert int(env_tesseract_dep) == tesseract_dep, \ - 'TESSERACT_DEP has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' + assert int(env_tesseract_dim) == tesseract_dim, ( + "TESSERACT_DIM has been set in the current environment and " + "does not match with the value passed to this initialized" + ) + assert int(env_tesseract_dep) == tesseract_dep, ( + "TESSERACT_DEP has been set in the current environment and " + "does not match with the value passed to this initialized" + ) else: env.tesseract_dim = tesseract_dim env.tesseract_dep = tesseract_dep @@ -50,8 +52,9 @@ def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): self.num_group = self.world_size // self.tensor_parallel_size self.tesseract_dep = tesseract_dep self.tesseract_dim = tesseract_dim - assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ - "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" + assert ( + self.tensor_parallel_size == self.tesseract_dim**2 * self.tesseract_dep + ), "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" def init_dist_group(self): """Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu. @@ -75,7 +78,7 @@ def init_dist_group(self): for i in range(self.tesseract_dim) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -129,7 +132,7 @@ def init_dist_group(self): for j in range(self.tesseract_dim) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -183,7 +186,7 @@ def init_dist_group(self): for k in range(self.tesseract_dep) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -238,7 +241,7 @@ def init_dist_group(self): for j in range(self.tesseract_dim) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -265,16 +268,25 @@ class Initializer_2p5D(ProcessGroupInitializer): depth (int): The depth of 2.5d parallel. """ - def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, - tensor_parallel_size: int, depth: int): + def __init__( + self, + rank: int, + world_size: int, + config: Config, + data_parallel_size: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + depth: int, + ): args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size) super().__init__(*args) self.num_group = self.world_size // self.tensor_parallel_size self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth)) self.tesseract_dep = depth - assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \ - "2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5" + assert ( + self.tensor_parallel_size == self.tesseract_dim**2 * self.tesseract_dep + ), "2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5" _check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep) self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args) @@ -293,6 +305,6 @@ def init_dist_group(self): self.col_initializer.init_dist_group(), self.row_initializer.init_dist_group(), self.dep_initializer.init_dist_group(), - self.xz_initializer.init_dist_group() + self.xz_initializer.init_dist_group(), ] return parallel_setting diff --git a/colossalai/legacy/context/process_group_initializer/initializer_3d.py b/colossalai/legacy/context/process_group_initializer/initializer_3d.py index 3dfbf5223b12..5f96405e90aa 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_3d.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_3d.py @@ -17,9 +17,10 @@ def _check_depth_env_var(depth): env_depth = env.depth_3d if env_depth: - assert int(env_depth) == depth, \ - 'DEPTH_3D has been set in the current environment and ' \ - 'does not match with the value passed to this initialized' + assert int(env_depth) == depth, ( + "DEPTH_3D has been set in the current environment and " + "does not match with the value passed to this initialized" + ) else: env.depth_3d = depth @@ -63,7 +64,7 @@ def init_dist_group(self): for k in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -114,7 +115,7 @@ def init_dist_group(self): for j in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -165,7 +166,7 @@ def init_dist_group(self): for j in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -219,7 +220,7 @@ def init_dist_group(self): for i in range(self.depth) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -273,7 +274,7 @@ def init_dist_group(self): for i in range(self.depth) ] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -302,8 +303,9 @@ def __init__(self, *args): super().__init__(*args) self.num_group = self.world_size // self.tensor_parallel_size self.depth = round(math.pow(self.tensor_parallel_size, 1 / 3)) - assert self.tensor_parallel_size == self.depth ** 3, \ - f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})' + assert ( + self.tensor_parallel_size == self.depth**3 + ), f"3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})" _check_depth_env_var(self.depth) self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args) @@ -324,6 +326,6 @@ def init_dist_group(self): self.weight_initializer.init_dist_group(), self.output_initializer.init_dist_group(), self.input_x_weight_initializer.init_dist_group(), - self.output_x_weight_initializer.init_dist_group() + self.output_x_weight_initializer.init_dist_group(), ] return parallel_setting diff --git a/colossalai/legacy/context/process_group_initializer/initializer_data.py b/colossalai/legacy/context/process_group_initializer/initializer_data.py index b9dec4541dad..9c8bcf353c20 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_data.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_data.py @@ -43,7 +43,7 @@ def init_dist_group(self): for i in range(self.num_data_parallel_group): ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/legacy/context/process_group_initializer/initializer_model.py b/colossalai/legacy/context/process_group_initializer/initializer_model.py index 614ba372fbcc..6aeae27756e7 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_model.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_model.py @@ -45,7 +45,7 @@ def init_dist_group(self): for i in range(self.num_group): ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py b/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py index e093333ad18a..3e69be75ff7e 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_pipeline.py @@ -38,10 +38,11 @@ def init_dist_group(self): for i in range(self.data_parallel_size): for j in range(self.pipeline_stage_size): pipe_ranks = list( - range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size)) + range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size) + ) pipe_group_size = len(pipe_ranks) pipe_group = dist.new_group(pipe_ranks) - group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group + group_cpu = dist.new_group(pipe_ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group if self.rank in pipe_ranks: local_rank = pipe_ranks.index(self.rank) @@ -50,7 +51,16 @@ def init_dist_group(self): cpu_group = group_cpu ranks_in_group = pipe_ranks dist_settings.append( - tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, - ParallelMode.PIPELINE))) + tuple( + ( + local_rank, + group_world_size, + process_group, + cpu_group, + ranks_in_group, + ParallelMode.PIPELINE, + ) + ) + ) return dist_settings diff --git a/colossalai/legacy/context/process_group_initializer/initializer_sequence.py b/colossalai/legacy/context/process_group_initializer/initializer_sequence.py index a6e26b6bcaa9..638b6d5ef2a6 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_sequence.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_sequence.py @@ -46,7 +46,7 @@ def init_dist_group(self): for i in range(self.num_group): ranks = [i * self.dp_size + j for j in range(self.dp_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) @@ -91,8 +91,14 @@ def init_dist_group(self): parallel_setting = [] - local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode = \ - self._sequence_initializer.init_dist_group() + ( + local_rank, + group_world_size, + process_group, + cpu_group, + ranks_in_group, + mode, + ) = self._sequence_initializer.init_dist_group() # change mode to sequence mode = ParallelMode.SEQUENCE diff --git a/colossalai/legacy/context/process_group_initializer/initializer_tensor.py b/colossalai/legacy/context/process_group_initializer/initializer_tensor.py index 3be89e52a812..cb19a43bd373 100644 --- a/colossalai/legacy/context/process_group_initializer/initializer_tensor.py +++ b/colossalai/legacy/context/process_group_initializer/initializer_tensor.py @@ -43,7 +43,7 @@ def init_dist_group(self): for i in range(self.num_tensor_parallel_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] group = dist.new_group(ranks) - group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group if self.rank in ranks: local_rank = ranks.index(self.rank) diff --git a/colossalai/legacy/context/process_group_initializer/process_group_initializer.py b/colossalai/legacy/context/process_group_initializer/process_group_initializer.py index 98150ce8e428..98b5d7fc3882 100644 --- a/colossalai/legacy/context/process_group_initializer/process_group_initializer.py +++ b/colossalai/legacy/context/process_group_initializer/process_group_initializer.py @@ -18,8 +18,15 @@ class ProcessGroupInitializer(ABC): tensor_parallel_size (int): Size of tensor parallel. """ - def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, - tensor_parallel_size: int): + def __init__( + self, + rank: int, + world_size: int, + config: Config, + data_parallel_size: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ): self.rank = rank self.world_size = world_size self.data_parallel_size = data_parallel_size diff --git a/colossalai/legacy/context/random/__init__.py b/colossalai/legacy/context/random/__init__.py index d64b993257c1..5e8d82922ddc 100644 --- a/colossalai/legacy/context/random/__init__.py +++ b/colossalai/legacy/context/random/__init__.py @@ -13,6 +13,15 @@ ) __all__ = [ - 'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states', - 'sync_states', 'moe_set_seed', 'reset_seeds' + "seed", + "set_mode", + "with_seed", + "add_seed", + "get_seeds", + "get_states", + "get_current_mode", + "set_seed_states", + "sync_states", + "moe_set_seed", + "reset_seeds", ] diff --git a/colossalai/legacy/context/random/_helper.py b/colossalai/legacy/context/random/_helper.py index 4b5d5ef2fe55..be1d951d1229 100644 --- a/colossalai/legacy/context/random/_helper.py +++ b/colossalai/legacy/context/random/_helper.py @@ -100,7 +100,7 @@ def sync_states(): @contextmanager def seed(parallel_mode: ParallelMode): - """ A context for seed switch + """A context for seed switch Examples: @@ -162,6 +162,7 @@ def wrapper(*args, **kwargs): def moe_set_seed(seed): if torch.cuda.is_available(): from colossalai.legacy.core import global_context as gpc + global_rank = gpc.get_global_rank() diff_seed = seed + global_rank add_seed(ParallelMode.TENSOR, diff_seed, True) diff --git a/colossalai/legacy/context/random/seed_manager.py b/colossalai/legacy/context/random/seed_manager.py index b657ff7e1d32..c90e849631a1 100644 --- a/colossalai/legacy/context/random/seed_manager.py +++ b/colossalai/legacy/context/random/seed_manager.py @@ -42,7 +42,7 @@ def set_state(self, parallel_mode: ParallelMode, state: Tensor): Raises: AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager. """ - assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager' + assert parallel_mode in self._seed_states, f"Parallel mode {parallel_mode} is not found in the seed manager" self._seed_states[parallel_mode] = state def set_mode(self, parallel_mode: ParallelMode): @@ -71,9 +71,9 @@ def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = Fal AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode` or the seed for `parallel_mode` has been added. """ - assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided' + assert isinstance(parallel_mode, ParallelMode), "A valid ParallelMode must be provided" if overwrite is False: - assert parallel_mode not in self._seed_states, f'The seed for {parallel_mode} has been added' + assert parallel_mode not in self._seed_states, f"The seed for {parallel_mode} has been added" elif parallel_mode in self._seed_states: print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True) diff --git a/colossalai/legacy/core.py b/colossalai/legacy/core.py index 0aaf1ee47730..80b6e4d25bd2 100644 --- a/colossalai/legacy/core.py +++ b/colossalai/legacy/core.py @@ -3,4 +3,4 @@ from colossalai.legacy.context.parallel_context import global_context -__all__ = ['global_context'] +__all__ = ["global_context"] diff --git a/colossalai/legacy/engine/__init__.py b/colossalai/legacy/engine/__init__.py index 158796befb31..581760748a16 100644 --- a/colossalai/legacy/engine/__init__.py +++ b/colossalai/legacy/engine/__init__.py @@ -1,4 +1,4 @@ from ._base_engine import Engine from .gradient_handler import * -__all__ = ['Engine'] +__all__ = ["Engine"] diff --git a/colossalai/legacy/engine/_base_engine.py b/colossalai/legacy/engine/_base_engine.py index 930caf20c1dd..0954e2be3eb1 100644 --- a/colossalai/legacy/engine/_base_engine.py +++ b/colossalai/legacy/engine/_base_engine.py @@ -59,15 +59,17 @@ class Engine: `Run resnet cifar10 with engine `_. """ - def __init__(self, - model: Module, - optimizer: "OptimizerWrapper", - criterion: Optional[_Loss] = None, - gradient_handlers: Optional[List[BaseGradientHandler]] = None, - clip_grad_norm: float = 0.0, - ophook_list: Optional[List[BaseOpHook]] = None, - verbose: bool = True, - schedule: Optional[BaseSchedule] = None): + def __init__( + self, + model: Module, + optimizer: "OptimizerWrapper", + criterion: Optional[_Loss] = None, + gradient_handlers: Optional[List[BaseGradientHandler]] = None, + clip_grad_norm: float = 0.0, + ophook_list: Optional[List[BaseOpHook]] = None, + verbose: bool = True, + schedule: Optional[BaseSchedule] = None, + ): self._model = model self._optimizer = optimizer self._criterion = criterion @@ -76,7 +78,7 @@ def __init__(self, self._logger = get_dist_logger() # state - self.training = True # default + self.training = True # default # build gradient handler if gradient_handlers: @@ -91,8 +93,9 @@ def __init__(self, # build schedule if schedule: - assert isinstance(schedule, BaseSchedule), \ - f'expected schedule to be of type BaseSchedule, but got {type(schedule)}' + assert isinstance( + schedule, BaseSchedule + ), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}" self._schedule = schedule else: self._schedule = NonPipelineSchedule() @@ -149,13 +152,11 @@ def remove_hook(self, ophook: Type[BaseOpHook]) -> None: logger.warning(f"removing hooks is currently not supported") def zero_grad(self): - """Set the gradient of parameters to zero - """ + """Set the gradient of parameters to zero""" self.optimizer.zero_grad() def step(self): - """Execute parameter update - """ + """Execute parameter update""" self._all_reduce_gradients() self.optimizer.clip_grad_by_norm(self._clip_grad_norm) return self.optimizer.step() @@ -192,8 +193,7 @@ def __call__(self, *args, **kwargs): return self.model(*args, **kwargs) def _all_reduce_gradients(self): - """Handles all-reduce operations of gradients across different parallel groups. - """ + """Handles all-reduce operations of gradients across different parallel groups.""" for handler in self._gradient_handlers: handler.handle_gradient() @@ -208,13 +208,11 @@ def execute_schedule(self, data_iter: Iterable, **kwargs): return output, label, loss def train(self): - """Sets the model to training mode. - """ + """Sets the model to training mode.""" self.training = True self._model.train() def eval(self): - """Sets the model to evaluation mode. - """ + """Sets the model to evaluation mode.""" self.training = False self._model.eval() diff --git a/colossalai/legacy/engine/gradient_accumulation/__init__.py b/colossalai/legacy/engine/gradient_accumulation/__init__.py index 670c26d06e55..e0835318ed9f 100644 --- a/colossalai/legacy/engine/gradient_accumulation/__init__.py +++ b/colossalai/legacy/engine/gradient_accumulation/__init__.py @@ -14,17 +14,22 @@ ) __all__ = [ - 'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep', - 'GradAccumGradientHandler' + "accumulate_gradient", + "GradAccumDataloader", + "GradAccumOptimizer", + "GradAccumLrSchedulerByStep", + "GradAccumGradientHandler", ] -def accumulate_gradient(model: nn.Module, - optimizer: Optimizer, - dataloader: Iterable, - accumulate_size: int, - gradient_handlers: List[BaseGradientHandler] = None, - lr_scheduler: _LRScheduler = None): +def accumulate_gradient( + model: nn.Module, + optimizer: Optimizer, + dataloader: Iterable, + accumulate_size: int, + gradient_handlers: List[BaseGradientHandler] = None, + lr_scheduler: _LRScheduler = None, +): r"""Turning model, optimizer, dataloader into corresponding object for gradient accumulation. Args: diff --git a/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py index c2270dc53a50..9de0f6c0ffd9 100644 --- a/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py +++ b/colossalai/legacy/engine/gradient_accumulation/_gradient_accumulation.py @@ -272,8 +272,9 @@ class GradAccumGradientHandler: """ def __init__(self, grad_handler: BaseGradientHandler, accumulate_size: int) -> None: - assert isinstance(grad_handler, BaseGradientHandler), \ - f'expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}' + assert isinstance( + grad_handler, BaseGradientHandler + ), f"expected grad_handler to be type BaseGradientHandler, but got {type(grad_handler)}" self.grad_handler = grad_handler self.accumulate_size = accumulate_size self.accumulate_step = 0 diff --git a/colossalai/legacy/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py index 2dea768bad7e..78928b138842 100644 --- a/colossalai/legacy/engine/gradient_handler/__init__.py +++ b/colossalai/legacy/engine/gradient_handler/__init__.py @@ -6,6 +6,10 @@ from ._zero_gradient_handler import ZeROGradientHandler __all__ = [ - 'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler', - 'MoeGradientHandler', 'SequenceParallelGradientHandler' + "BaseGradientHandler", + "DataParallelGradientHandler", + "ZeROGradientHandler", + "PipelineSharedModuleGradientHandler", + "MoeGradientHandler", + "SequenceParallelGradientHandler", ] diff --git a/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py index 7d96dd8a88a6..e594bb00f96b 100644 --- a/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_base_gradient_handler.py @@ -22,4 +22,3 @@ def handle_gradient(self): """A method to accumulate gradients across different parallel groups. Users should write their own functions or just use the functions in pre-defined subclasses. """ - pass diff --git a/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py index c692ee903442..3782adaf7187 100644 --- a/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_data_parallel_gradient_handler.py @@ -20,8 +20,7 @@ class DataParallelGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in a data parallel group. - """ + """A method running a all-reduce operation in a data parallel group.""" # TODO: add memory buffer if gpc.data_parallel_size > 1: bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA)) diff --git a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py index e7a6df2d8ae8..6a7224cff7bd 100644 --- a/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_moe_gradient_handler.py @@ -42,5 +42,6 @@ def handle_gradient(self): for ep_size in epsize_param_dict: if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: - bucket_allreduce(param_list=epsize_param_dict[ep_size], - group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group) + bucket_allreduce( + param_list=epsize_param_dict[ep_size], group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group + ) diff --git a/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py index 3eae7d58ac95..3a65f65abf73 100644 --- a/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_pipeline_parallel_gradient_handler.py @@ -26,17 +26,21 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in sub pipeline parallel groups. - """ + """A method running a all-reduce operation in sub pipeline parallel groups.""" if gpc.pipeline_parallel_size > 1: # bucketize and all-reduce buckets = defaultdict(lambda: defaultdict(list)) # Pack the buckets. for param in self._model.parameters(): - group = getattr(param, 'pipeline_shared_module_pg', None) - if param.requires_grad and group is not None and ( - (hasattr(param, 'colo_attr') and not param.colo_attr.saved_grad.is_null()) - or param.grad is not None): + group = getattr(param, "pipeline_shared_module_pg", None) + if ( + param.requires_grad + and group is not None + and ( + (hasattr(param, "colo_attr") and not param.colo_attr.saved_grad.is_null()) + or param.grad is not None + ) + ): tp = param.data.type() buckets[group][tp].append(param) @@ -44,7 +48,7 @@ def handle_gradient(self): for group, group_buckets in buckets.items(): for tp, bucket in group_buckets.items(): grads = [ - param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data + param.colo_attr.grad_payload if hasattr(param, "colo_attr") else param.grad.data for param in bucket ] coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device()) diff --git a/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py index 38b7f5993b73..6d507bcc0269 100644 --- a/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_sequence_parallel_gradient_handler.py @@ -20,7 +20,6 @@ class SequenceParallelGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in a data parallel group. - """ + """A method running a all-reduce operation in a data parallel group.""" if gpc.get_world_size(ParallelMode.SEQUENCE_DP) > 1: bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.SEQUENCE_DP)) diff --git a/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py index 4ca7cd0b0702..63ec6e70ba06 100644 --- a/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py +++ b/colossalai/legacy/engine/gradient_handler/_zero_gradient_handler.py @@ -16,6 +16,5 @@ class ZeROGradientHandler(BaseGradientHandler): """ def handle_gradient(self): - """A method running a all-reduce operation in a data parallel group. - """ + """A method running a all-reduce operation in a data parallel group.""" self._optimizer.sync_grad() diff --git a/colossalai/legacy/engine/schedule/__init__.py b/colossalai/legacy/engine/schedule/__init__.py index 0f2c039d7057..017231a9b4a8 100644 --- a/colossalai/legacy/engine/schedule/__init__.py +++ b/colossalai/legacy/engine/schedule/__init__.py @@ -2,4 +2,4 @@ from ._non_pipeline_schedule import NonPipelineSchedule from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape -__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape'] +__all__ = ["BaseSchedule", "NonPipelineSchedule", "PipelineSchedule", "InterleavedPipelineSchedule", "get_tensor_shape"] diff --git a/colossalai/legacy/engine/schedule/_base_schedule.py b/colossalai/legacy/engine/schedule/_base_schedule.py index 7505a3eb20e3..4a3ccfda1bb5 100644 --- a/colossalai/legacy/engine/schedule/_base_schedule.py +++ b/colossalai/legacy/engine/schedule/_base_schedule.py @@ -47,7 +47,8 @@ def _move_to_device(self, data): data = {k: self._move_tensor(v) for k, v in data.items()} else: raise TypeError( - f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}" + ) return data def _get_batch_size(self, data): @@ -72,7 +73,7 @@ def load_batch(self, data_iter, to_gpu=True): Tuple (:class:`Tensor`, :class:`torch.Tensor`): A tuple of (data, label). """ if data_iter is None: - raise RuntimeError('Dataloader is not defined.') + raise RuntimeError("Dataloader is not defined.") batch_data = next(data_iter) if to_gpu: @@ -81,17 +82,17 @@ def load_batch(self, data_iter, to_gpu=True): return batch_data def pre_processing(self, engine): - """To perform actions before running the schedule. - """ - pass + """To perform actions before running the schedule.""" @abstractmethod - def forward_backward_step(self, - engine, - data_iter: Iterable, - forward_only: bool, - return_loss: bool = True, - return_output_label: bool = True): + def forward_backward_step( + self, + engine, + data_iter: Iterable, + forward_only: bool, + return_loss: bool = True, + return_output_label: bool = True, + ): """The process function over a batch of dataset for training or evaluation. Args: @@ -101,7 +102,6 @@ def forward_backward_step(self, return_loss (bool, optional): If False, the loss won't be returned. return_output_label (bool, optional): If False, the output and label won't be returned. """ - pass @staticmethod def _call_engine(engine, inputs): @@ -113,13 +113,14 @@ def _call_engine(engine, inputs): return engine(**inputs) else: TypeError( - f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}") + f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}" + ) @staticmethod def _call_engine_criterion(engine, outputs, labels): - assert isinstance(outputs, - (torch.Tensor, list, tuple, - dict)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' + assert isinstance( + outputs, (torch.Tensor, list, tuple, dict) + ), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}" if isinstance(outputs, torch.Tensor): outputs = (outputs,) if isinstance(labels, torch.Tensor): @@ -134,6 +135,8 @@ def _call_engine_criterion(engine, outputs, labels): elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)): raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}") else: - raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor ' \ + raise TypeError( + f"Expected model outputs and labels to be of type torch.Tensor ' \ '(which is auto-converted to tuple), list, tuple, or dict, ' \ - 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)") + 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)" + ) diff --git a/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py index b67893c1a0bb..08c6cfd60f28 100644 --- a/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_non_pipeline_schedule.py @@ -37,19 +37,22 @@ def __init__(self, data_process_func: Callable = None): if data_process_func: sig = inspect.signature(data_process_func) - assert len(sig.parameters) == 1, \ - 'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \ - 'which is a tuple of tensors for the current batch, ' \ - 'i.e. data_process_func(dataloader_output).' + assert len(sig.parameters) == 1, ( + "The data_process_func only takes in one parameter for NonPipelineSchedule, " + "which is a tuple of tensors for the current batch, " + "i.e. data_process_func(dataloader_output)." + ) super().__init__(data_process_func) - def forward_backward_step(self, - engine, - data_iter: Iterable, - forward_only: bool = False, - return_loss: bool = True, - return_output_label: bool = True): + def forward_backward_step( + self, + engine, + data_iter: Iterable, + forward_only: bool = False, + return_loss: bool = True, + return_output_label: bool = True, + ): """The process function that loads a batch of dataset and feeds it to the model. The returned labels and loss will None if :attr:`return_loss` is False. @@ -64,8 +67,9 @@ def forward_backward_step(self, Returns: Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. """ - assert forward_only or return_loss, \ - "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." batch_data = self.load_batch(data_iter) if self.data_process_func: data, label = self.data_process_func(batch_data) diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py index 37eed82f8a28..4fc5040f6983 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py @@ -18,14 +18,18 @@ def get_tensor_shape(): - if hasattr(gpc.config, 'TENSOR_SHAPE'): + if hasattr(gpc.config, "TENSOR_SHAPE"): return gpc.config.TENSOR_SHAPE if not gpc.is_initialized(ParallelMode.PIPELINE): return None - if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr( - gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'): + if ( + hasattr(gpc.config, "SEQ_LENGTH") + and hasattr(gpc.config, "GLOBAL_BATCH_SIZE") + and hasattr(gpc.config, "GLOBAL_BATCH_SIZE") + and hasattr(gpc.config, "HIDDEN_SIZE") + ): if gpc.is_initialized(ParallelMode.DATA): dp_size = gpc.get_world_size(ParallelMode.DATA) else: @@ -35,8 +39,11 @@ def get_tensor_shape(): else: seq_size = 1 - tensor_shape = (gpc.config.SEQ_LENGTH // seq_size, - gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE) + tensor_shape = ( + gpc.config.SEQ_LENGTH // seq_size, + gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, + gpc.config.HIDDEN_SIZE, + ) return tensor_shape else: return None @@ -49,7 +56,7 @@ def pack_return_tensors(return_tensors): elif isinstance(output[0], (list, tuple)): output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) else: - raise TypeError(f'Output of model must be tensor or list/tuple of tensors') + raise TypeError(f"Output of model must be tensor or list/tuple of tensors") if isinstance(label[0], torch.Tensor): label = torch.cat(label, dim=0) else: @@ -88,28 +95,31 @@ def data_process_func(stage_output, dataloader_output): """ - def __init__(self, - num_microbatches, - data_process_func: Callable = None, - tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, - scatter_gather_tensors: bool = False): - + def __init__( + self, + num_microbatches, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False, + ): # we need to make sure that the signature of the data_process_func is valid if data_process_func: sig = inspect.signature(data_process_func) - assert len(sig.parameters) == 2, \ - 'The data_process_func only takes in two parameters for NonPipelineSchedule, ' \ - 'which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \ - 'i.e. data_process_func(stage_output, dataloader_output).' + assert len(sig.parameters) == 2, ( + "The data_process_func only takes in two parameters for NonPipelineSchedule, " + "which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, " + "i.e. data_process_func(stage_output, dataloader_output)." + ) super().__init__(data_process_func=data_process_func) - assert num_microbatches > 0, f'expected num_microbatches to be larger then 1, but got {num_microbatches}' + assert num_microbatches > 0, f"expected num_microbatches to be larger then 1, but got {num_microbatches}" self.num_microbatches = num_microbatches self.dtype = torch.float - assert not isinstance(tensor_shape, - int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]." + assert not isinstance( + tensor_shape, int + ), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]." if tensor_shape is None: self.tensor_shape = tensor_shape elif isinstance(tensor_shape, torch.Size): @@ -128,26 +138,25 @@ def load_batch(self, data_iter): # Pipeline schedule just puts data in memory batch_data = super().load_batch(data_iter, to_gpu=False) self.microbatch_offset = 0 - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" + assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches" self.microbatch_size = self.batch_size // self.num_microbatches self.batch_data = batch_data def _get_data_slice(self, data, offset): if isinstance(data, torch.Tensor): - return data[offset:offset + self.microbatch_size] + return data[offset : offset + self.microbatch_size] elif isinstance(data, (list, tuple)): data_dict = {} for element in data: if isinstance(element, dict): - data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()}) + data_dict.update({k: v[offset : offset + self.microbatch_size] for k, v in element.items()}) elif data_dict: - data_dict['label'] = element[offset:offset + self.microbatch_size] + data_dict["label"] = element[offset : offset + self.microbatch_size] if data_dict: return data_dict - return [val[offset:offset + self.microbatch_size] for val in data] + return [val[offset : offset + self.microbatch_size] for val in data] elif isinstance(data, dict): - return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()} + return {k: v[offset : offset + self.microbatch_size] for k, v in data.items()} else: raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") @@ -180,8 +189,8 @@ def _call_engine(model, data): return model(*data) elif isinstance(data, dict): stage_output = None - if 'stage_output' in data: - stage_output = data.pop('stage_output') + if "stage_output" in data: + stage_output = data.pop("stage_output") if stage_output is None: return model(**data) elif isinstance(stage_output, torch.Tensor): @@ -198,7 +207,7 @@ def _call_engine(model, data): def _get_actual_forward_func(self, module): if isinstance(module, NaiveAMPModel): sig = inspect.signature(module.model.forward) - elif hasattr(module, 'colo_attr'): + elif hasattr(module, "colo_attr"): sig = inspect.signature(module.module.forward) else: sig = inspect.signature(module.forward) @@ -221,9 +230,9 @@ def _get_data_label_for_current_step(self, stage_output, micro_batch_data, crite _, label = micro_batch_data elif isinstance(micro_batch_data, dict): data = {} - data['stage_output'] = stage_output - if 'label' in micro_batch_data: - label = micro_batch_data.pop('label') + data["stage_output"] = stage_output + if "label" in micro_batch_data: + label = micro_batch_data.pop("label") else: label = None load_data = micro_batch_data @@ -263,7 +272,7 @@ def _forward_step(self, engine, input_obj, return_tensors, return_output_label=T else: if isinstance(output_obj, torch.Tensor): self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' + f"Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}" ) return output_obj @@ -325,12 +334,13 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. """ - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." self.load_batch(data_iter) - num_warmup_microbatches = \ - (gpc.get_world_size(ParallelMode.PIPELINE) - - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) + num_warmup_microbatches = ( + gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1 + ) num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches @@ -354,14 +364,12 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo for i in range(num_warmup_microbatches): if not gpc.is_first_rank(ParallelMode.PIPELINE): ft_shapes = comm.recv_obj_meta(ft_shapes) - input_obj = comm.recv_forward(ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + input_obj = comm.recv_forward( + ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) if not gpc.is_last_rank(ParallelMode.PIPELINE): if isinstance(output_obj, torch.Tensor): bt_shapes = output_obj.shape @@ -382,32 +390,29 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo if num_microbatches_remaining > 0: if not gpc.is_first_rank(ParallelMode.PIPELINE): ft_shapes = comm.recv_obj_meta(ft_shapes) - input_obj = comm.recv_forward(ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj = comm.recv_forward( + ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) # Run 1F1B in steady state. for i in range(num_microbatches_remaining): - last_iteration = (i == (num_microbatches_remaining - 1)) + last_iteration = i == (num_microbatches_remaining - 1) - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) if forward_only: comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors) if not last_iteration: - input_obj = comm.recv_forward(ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj = comm.recv_forward( + ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) else: - output_obj_grad = comm.send_forward_recv_backward(output_obj, - bt_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + output_obj_grad = comm.send_forward_recv_backward( + output_obj, bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) # Add input_obj and output_obj to end of list. input_objs.append(input_obj) @@ -424,10 +429,9 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo input_obj = None comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) else: - input_obj = comm.send_backward_recv_forward(input_obj_grad, - ft_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj = comm.send_backward_recv_forward( + input_obj_grad, ft_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) # Run cooldown backward passes. if not forward_only: @@ -435,9 +439,9 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) - output_obj_grad = comm.recv_backward(bt_shapes, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + output_obj_grad = comm.recv_backward( + bt_shapes, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors + ) input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) @@ -451,13 +455,14 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo class InterleavedPipelineSchedule(PipelineSchedule): - - def __init__(self, - num_microbatches: int, - num_model_chunks: int, - data_process_func: Callable = None, - tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, - scatter_gather_tensors: bool = False): + def __init__( + self, + num_microbatches: int, + num_model_chunks: int, + data_process_func: Callable = None, + tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None, + scatter_gather_tensors: bool = False, + ): """A helper schedule class for pipeline parallelism running environment. It uses interleaved 1F1B strategy. Other properties are similar as :class:`NonPipelineSchedule`. @@ -471,20 +476,25 @@ def __init__(self, scatter_gather_tensors (bool, optional): If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. """ - assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \ - 'num_microbatches must be an integer multiple of pipeline parallel world size' - assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \ - f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}' - super().__init__(num_microbatches, - data_process_func=data_process_func, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather_tensors) + assert ( + num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0 + ), "num_microbatches must be an integer multiple of pipeline parallel world size" + assert ( + isinstance(num_model_chunks, int) and num_model_chunks > 0 + ), f"expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}" + super().__init__( + num_microbatches, + data_process_func=data_process_func, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather_tensors, + ) gpc.set_virtual_pipeline_parallel_size(num_model_chunks) gpc.set_virtual_pipeline_parallel_rank(0) self.num_model_chunks = num_model_chunks def pre_processing(self, engine): from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 + if isinstance(engine.model, ShardedModelV2): self.dtype = torch.half elif isinstance(engine.model[0], NaiveAMPModel): @@ -494,7 +504,7 @@ def pre_processing(self, engine): model = model.model sig = inspect.signature(model.forward) for p in sig.parameters.values(): - assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' + assert p.kind != inspect.Parameter.VAR_POSITIONAL, "*args is not supported" def load_batch(self, data_iter): super().load_batch(data_iter) @@ -506,13 +516,9 @@ def load_micro_batch(self, model_chunk_id): self.microbatch_offset[model_chunk_id] += self.microbatch_size return self._move_to_device(data) - def _forward_step(self, - engine, - model_chunk_id, - input_obj, - return_tensors, - return_output_label=True, - accum_loss=None): + def _forward_step( + self, engine, model_chunk_id, input_obj, return_tensors, return_output_label=True, accum_loss=None + ): """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_obj is used. Returns output tensor. This is a helper function and can be ignored by users. @@ -528,8 +534,9 @@ def _forward_step(self, Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. """ micro_batch_data = self.load_micro_batch(model_chunk_id) - data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, - engine.model[model_chunk_id]) + data, label = self._get_data_label_for_current_step( + input_obj, micro_batch_data, engine.criterion, engine.model[model_chunk_id] + ) output_obj = self._call_engine(engine.model[model_chunk_id], data) @@ -546,7 +553,7 @@ def _forward_step(self, else: if isinstance(output_obj, torch.Tensor): self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' + f"Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}" ) return output_obj @@ -566,8 +573,9 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. The loss would be returned only in the last stage. """ - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." self.load_batch(data_iter) model = engine.model input_objs = [[] for _ in range(len(model))] @@ -605,19 +613,17 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo num_warmup_microbatches = num_microbatches all_warmup_microbatches = True else: - num_warmup_microbatches = \ - (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining = \ - num_microbatches - num_warmup_microbatches + num_microbatches_remaining = num_microbatches - num_warmup_microbatches def get_model_chunk_id(microbatch_id, forward): """Helper method to get the model chunk ID given the iteration number.""" microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) model_chunk_id = microbatch_id_in_group // pipeline_parallel_size if not forward: - model_chunk_id = (num_model_chunks - model_chunk_id - 1) + model_chunk_id = num_model_chunks - model_chunk_id - 1 return model_chunk_id def _forward_step_helper(microbatch_id): @@ -629,16 +635,17 @@ def _forward_step_helper(microbatch_id): # forward step if gpc.is_pipeline_first_stage(): - if len(input_objs[model_chunk_id]) == \ - len(output_objs[model_chunk_id]): + if len(input_objs[model_chunk_id]) == len(output_objs[model_chunk_id]): input_objs[model_chunk_id].append(None) input_obj = input_objs[model_chunk_id][-1] - output_obj = self._forward_step(engine, - model_chunk_id, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + output_obj = self._forward_step( + engine, + model_chunk_id, + input_obj, + return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss, + ) output_objs[model_chunk_id].append(output_obj) # if forward-only, no need to save tensors for a backward pass @@ -670,8 +677,8 @@ def _backward_step_helper(microbatch_id): if not gpc.is_pipeline_first_stage(): input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0]) input_objs[0].append( - comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors)) + comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) + ) for k in range(num_warmup_microbatches): model_chunk_id = get_model_chunk_id(k, forward=True) @@ -683,8 +690,9 @@ def _backward_step_helper(microbatch_id): output_obj_shapes[model_chunk_id] = [] for out_tensor in output_obj: output_obj_shapes[model_chunk_id].append(out_tensor.shape) - send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj, - send_tensor_shape_flags[model_chunk_id]) + send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta( + output_obj, send_tensor_shape_flags[model_chunk_id] + ) # Determine if tensor should be received from previous stage. next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) recv_prev = True @@ -701,34 +709,36 @@ def _backward_step_helper(microbatch_id): with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id): if not gpc.is_pipeline_first_stage(): input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta( - input_obj_shapes[next_forward_model_chunk_id]) + input_obj_shapes[next_forward_model_chunk_id] + ) # Send and receive tensors as appropriate (send tensors computed # in this iteration; receive tensors for next iteration). input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None - if k == (num_warmup_microbatches - 1) and not forward_only and \ - not all_warmup_microbatches: + if k == (num_warmup_microbatches - 1) and not forward_only and not all_warmup_microbatches: input_obj_grad = None recv_next = True if gpc.is_pipeline_last_stage(ignore_virtual=True): recv_next = False output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None - input_obj, output_obj_grad = \ - comm.send_forward_backward_recv_forward_backward( - output_obj, input_obj_grad, - input_shape, - output_shape, - recv_prev=recv_prev, recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward( + output_obj, + input_obj_grad, + input_shape, + output_shape, + recv_prev=recv_prev, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) output_obj_grads[num_model_chunks - 1].append(output_obj_grad) else: - input_obj = \ - comm.send_forward_recv_forward( - output_obj, - input_shape, - recv_prev=recv_prev, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj = comm.send_forward_recv_forward( + output_obj, + input_shape, + recv_prev=recv_prev, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) input_objs[next_forward_model_chunk_id].append(input_obj) # Run 1F1B in steady state. @@ -771,8 +781,9 @@ def _backward_step_helper(microbatch_id): recv_next = True if gpc.is_pipeline_last_stage(ignore_virtual=True): # Last stage is ahead of first stage by (pipeline_parallel_size - 1). - next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1), - forward=False) + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) if next_backward_model_chunk_id == 0: recv_next = False next_backward_model_chunk_id -= 1 @@ -787,14 +798,16 @@ def _backward_step_helper(microbatch_id): input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None # Communicate objs. - input_obj, output_obj_grad = \ - comm.send_forward_backward_recv_forward_backward( - output_obj, input_obj_grad, - input_shape, - output_shape, - recv_prev=recv_prev, recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward( + output_obj, + input_obj_grad, + input_shape, + output_shape, + recv_prev=recv_prev, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) # Put input_obj and output_obj_grad in data structures in the # right location. @@ -807,8 +820,10 @@ def _backward_step_helper(microbatch_id): if not forward_only: if all_warmup_microbatches: output_obj_grads[num_model_chunks - 1].append( - comm.recv_backward(output_obj_shapes[num_model_chunks - 1], - scatter_gather_tensors=self.scatter_gather_tensors)) + comm.recv_backward( + output_obj_shapes[num_model_chunks - 1], scatter_gather_tensors=self.scatter_gather_tensors + ) + ) for k in range(num_microbatches_remaining, num_microbatches): input_obj_grad = _backward_step_helper(k) next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) @@ -820,11 +835,14 @@ def _backward_step_helper(microbatch_id): recv_next = False output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None output_obj_grads[next_backward_model_chunk_id].append( - comm.send_backward_recv_backward(input_obj_grad, - output_shape, - recv_next=recv_next, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors)) + comm.send_backward_recv_backward( + input_obj_grad, + output_shape, + recv_next=recv_next, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors, + ) + ) if len(return_tensors) > 0: output, label = pack_return_tensors(return_tensors) diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py index bf8b599a81ae..867c3dfa819b 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py @@ -21,7 +21,7 @@ def pack_return_tensors(return_tensors): elif isinstance(output[0], (list, tuple)): output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output)) else: - raise TypeError(f'Output of model must be tensor or list/tuple of tensors') + raise TypeError(f"Output of model must be tensor or list/tuple of tensors") if isinstance(label[0], torch.Tensor): label = torch.cat(label, dim=0) else: @@ -59,12 +59,9 @@ def data_process_func(stage_output, dataloader_output): """ - def forward_backward_step(self, - engine: Engine, - data_iter: Iterable, - forward_only=False, - return_loss=True, - return_output_label=True) -> Tuple[torch.Tensor]: + def forward_backward_step( + self, engine: Engine, data_iter: Iterable, forward_only=False, return_loss=True, return_output_label=True + ) -> Tuple[torch.Tensor]: """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Returns a tuple with losses if the last stage, an empty tuple otherwise. @@ -80,14 +77,15 @@ def forward_backward_step(self, Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. """ - assert forward_only or return_loss, \ - 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' + assert ( + forward_only or return_loss + ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." self.load_batch(data_iter) # num_warmup_microbatches is the step when not all the processes are working - num_warmup_microbatches = \ - (gpc.get_world_size(ParallelMode.PIPELINE) - - gpc.get_local_rank(ParallelMode.PIPELINE) - 1) + num_warmup_microbatches = ( + gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1 + ) num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches) num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches @@ -109,11 +107,9 @@ def forward_backward_step(self, for i in range(num_warmup_microbatches): input_obj = comm.recv_forward() - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) comm.send_forward(output_obj) @@ -129,13 +125,11 @@ def forward_backward_step(self, # Run 1F1B in steady state. for i in range(num_microbatches_remaining): - last_iteration = (i == (num_microbatches_remaining - 1)) + last_iteration = i == (num_microbatches_remaining - 1) - output_obj = self._forward_step(engine, - input_obj, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) + output_obj = self._forward_step( + engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss + ) if forward_only: comm.send_forward(output_obj) diff --git a/colossalai/legacy/global_variables.py b/colossalai/legacy/global_variables.py index 61b31965e2e6..93cd5e60fa61 100644 --- a/colossalai/legacy/global_variables.py +++ b/colossalai/legacy/global_variables.py @@ -12,19 +12,21 @@ def __new__(cls, *args, **kwargs): def __init__(self, *args, **kwargs): self.load(*args, **kwargs) - def load(self, - mode: Optional[str] = None, - vocab_parallel: bool = False, - parallel_input_1d: bool = False, - summa_dim: int = None, - tesseract_dim: int = None, - tesseract_dep: int = None, - depth_3d: int = None, - input_group_3d=None, - weight_group_3d=None, - output_group_3d=None, - input_x_weight_group_3d=None, - output_x_weight_group_3d=None): + def load( + self, + mode: Optional[str] = None, + vocab_parallel: bool = False, + parallel_input_1d: bool = False, + summa_dim: int = None, + tesseract_dim: int = None, + tesseract_dep: int = None, + depth_3d: int = None, + input_group_3d=None, + weight_group_3d=None, + output_group_3d=None, + input_x_weight_group_3d=None, + output_x_weight_group_3d=None, + ): self.mode = mode self.vocab_parallel = vocab_parallel self.parallel_input_1d = parallel_input_1d @@ -39,18 +41,20 @@ def load(self, self.output_x_weight_group_3d = output_x_weight_group_3d def save(self): - return dict(mode=self.mode, - vocab_parallel=self.vocab_parallel, - parallel_input_1d=self.parallel_input_1d, - summa_dim=self.summa_dim, - tesseract_dim=self.tesseract_dim, - tesseract_dep=self.tesseract_dep, - depth_3d=self.depth_3d, - input_group_3d=self.input_group_3d, - weight_group_3d=self.weight_group_3d, - output_group_3d=self.output_group_3d, - input_x_weight_group_3d=self.input_x_weight_group_3d, - output_x_weight_group_3d=self.output_x_weight_group_3d) + return dict( + mode=self.mode, + vocab_parallel=self.vocab_parallel, + parallel_input_1d=self.parallel_input_1d, + summa_dim=self.summa_dim, + tesseract_dim=self.tesseract_dim, + tesseract_dep=self.tesseract_dep, + depth_3d=self.depth_3d, + input_group_3d=self.input_group_3d, + weight_group_3d=self.weight_group_3d, + output_group_3d=self.output_group_3d, + input_x_weight_group_3d=self.input_x_weight_group_3d, + output_x_weight_group_3d=self.output_x_weight_group_3d, + ) tensor_parallel_env = TensorParallelEnv() diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py index 2c253adbaf38..ce9c626553bf 100644 --- a/colossalai/legacy/initialize.py +++ b/colossalai/legacy/initialize.py @@ -47,25 +47,27 @@ def get_default_parser(): Namespace: Returns the parser with the default arguments, the user may add customized arguments into this parser. """ parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, help='path to the config file') - parser.add_argument('--host', type=str, help='the master address for distributed training') - parser.add_argument('--port', type=int, help='the master port for distributed training') - parser.add_argument('--world_size', type=int, help='world size for distributed training') - parser.add_argument('--rank', type=int, help='rank for the default process group') - parser.add_argument('--local_rank', type=int, help='local rank on the node') - parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication') + parser.add_argument("--config", type=str, help="path to the config file") + parser.add_argument("--host", type=str, help="the master address for distributed training") + parser.add_argument("--port", type=int, help="the master port for distributed training") + parser.add_argument("--world_size", type=int, help="world size for distributed training") + parser.add_argument("--rank", type=int, help="rank for the default process group") + parser.add_argument("--local_rank", type=int, help="local rank on the node") + parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication") return parser -def launch(config: Union[str, Path, Config, Dict], - rank: int, - world_size: int, - host: str, - port: int, - backend: str = 'nccl', - local_rank: int = None, - seed: int = 1024, - verbose: bool = True): +def launch( + config: Union[str, Path, Config, Dict], + rank: int, + world_size: int, + host: str, + port: int, + backend: str = "nccl", + local_rank: int = None, + seed: int = 1024, + verbose: bool = True, +): """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input arguments are not given. Then initialize and set distributed environment by calling global_context's functions. @@ -88,8 +90,9 @@ def launch(config: Union[str, Path, Config, Dict], gpc.verbose = verbose # set config - assert isinstance(config, (Config, str, Path, dict)), \ - f'expected argument config to be Config, str or Path, but got {type(config)}' + assert isinstance( + config, (Config, str, Path, dict) + ), f"expected argument config to be Config, str or Path, but got {type(config)}" if not isinstance(config, Config) and isinstance(config, dict): config = Config(config) if isinstance(config, (str, Path)): @@ -115,18 +118,21 @@ def launch(config: Union[str, Path, Config, Dict], if verbose: logger = get_dist_logger() logger.info( - f'Distributed environment is initialized, ' - f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, ' - f'tensor parallel size: {gpc.tensor_parallel_size}', - ranks=[0]) - - -def launch_from_slurm(config: Union[str, Path, Config, Dict], - host: str, - port: int, - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + f"Distributed environment is initialized, " + f"data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, " + f"tensor parallel size: {gpc.tensor_parallel_size}", + ranks=[0], + ) + + +def launch_from_slurm( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables set by SLURM @@ -139,29 +145,33 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['SLURM_PROCID']) - world_size = int(os.environ['SLURM_NPROCS']) + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NPROCS"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the SLURM environment, visit https://www.colossalai.org/ for more information on launching with SLURM" ) - launch(config=config, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def launch_from_openmpi(config: Union[str, Path, Config, Dict], - host: str, - port: int, - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + launch( + config=config, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_openmpi( + config: Union[str, Path, Config, Dict], + host: str, + port: int, + backend: str = "nccl", + seed: int = 1024, + verbose: bool = True, +): """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables set by OpenMPI @@ -174,29 +184,30 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['OMPI_COMM_WORLD_RANK']) - local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) - world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the OpenMPI environment, visit https://www.colossalai.org/ for more information on launching with OpenMPI" ) - launch(config=config, - local_rank=local_rank, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def launch_from_torch(config: Union[str, Path, Config, Dict], - backend: str = 'nccl', - seed: int = 1024, - verbose: bool = True): + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def launch_from_torch( + config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True +): """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size from the environment variables set by PyTorch @@ -207,35 +218,39 @@ def launch_from_torch(config: Union[str, Path, Config, Dict], verbose (bool, optional): Whether to print logs. Defaults to True. """ try: - rank = int(os.environ['RANK']) - local_rank = int(os.environ['LOCAL_RANK']) - world_size = int(os.environ['WORLD_SIZE']) - host = os.environ['MASTER_ADDR'] - port = int(os.environ['MASTER_PORT']) + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + host = os.environ["MASTER_ADDR"] + port = int(os.environ["MASTER_PORT"]) except KeyError as e: raise RuntimeError( f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" ) - launch(config=config, - local_rank=local_rank, - rank=rank, - world_size=world_size, - host=host, - port=port, - backend=backend, - seed=seed, - verbose=verbose) - - -def initialize(model: nn.Module, - optimizer: Optimizer, - criterion: Optional[_Loss] = None, - train_dataloader: Optional[Iterable] = None, - test_dataloader: Optional[Iterable] = None, - lr_scheduler: Optional[_LRScheduler] = None, - ophooks: Optional[List[BaseOpHook]] = None, - verbose: bool = True) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: + launch( + config=config, + local_rank=local_rank, + rank=rank, + world_size=world_size, + host=host, + port=port, + backend=backend, + seed=seed, + verbose=verbose, + ) + + +def initialize( + model: nn.Module, + optimizer: Optimizer, + criterion: Optional[_Loss] = None, + train_dataloader: Optional[Iterable] = None, + test_dataloader: Optional[Iterable] = None, + lr_scheduler: Optional[_LRScheduler] = None, + ophooks: Optional[List[BaseOpHook]] = None, + verbose: bool = True, +) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]: """Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config. @@ -267,30 +282,30 @@ def initialize(model: nn.Module, f"\n========== Your Config ========\n" f"{pprint.pformat(gpc.config)}\n" f"================================\n", - ranks=[0]) + ranks=[0], + ) # cudnn - cudnn_benchmark = config.get('cudnn_benchmark', False) - cudnn_deterministic = config.get('cudnn_deterministic', False) + cudnn_benchmark = config.get("cudnn_benchmark", False) + cudnn_deterministic = config.get("cudnn_deterministic", False) torch.backends.cudnn.benchmark = cudnn_benchmark torch.backends.cudnn.deterministic = cudnn_deterministic if verbose: logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0]) # zero - use_zero = hasattr(gpc.config, 'zero') + use_zero = hasattr(gpc.config, "zero") if use_zero: - zero_cfg = gpc.config.get('zero', None) + zero_cfg = gpc.config.get("zero", None) if zero_cfg is not None: cfg_ = zero_cfg.copy() else: cfg_ = {} - optimizer_config = zero_cfg.get('optimizer_config', None) - model_config = zero_cfg.get('model_config', None) - model, optimizer = convert_to_zero_v2(model, - optimizer, - model_config=model_config, - optimizer_config=optimizer_config) + optimizer_config = zero_cfg.get("optimizer_config", None) + model_config = zero_cfg.get("model_config", None) + model, optimizer = convert_to_zero_v2( + model, optimizer, model_config=model_config, optimizer_config=optimizer_config + ) logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) else: @@ -316,38 +331,38 @@ def initialize(model: nn.Module, logger.warning( "The parameters of models is not automatically synchronized.\n" "Please make sure that all parameters are the same in data parallel group.", - ranks=[0]) + ranks=[0], + ) # check amp and zero - fp16_cfg = gpc.config.get('fp16', None) + fp16_cfg = gpc.config.get("fp16", None) if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero: raise ConfigException( - "It is not allowed to set fp16 and zero configuration in your config file at the same time") + "It is not allowed to set fp16 and zero configuration in your config file at the same time" + ) # clip grad norm - clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0) + clip_grad_norm = gpc.config.get("clip_grad_norm", 0.0) # initialize amp amp_mode = None if fp16_cfg is not None and fp16_cfg.mode is not None: cfg_ = fp16_cfg.copy() - amp_mode = cfg_.pop('mode') + amp_mode = cfg_.pop("mode") if is_using_pp(): - assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently' + assert amp_mode == AMP_TYPE.NAIVE, "Pipeline only support NaiveAMP currently" if amp_mode == AMP_TYPE.NAIVE: - cfg_['clip_grad_norm'] = clip_grad_norm - model, optimizer, criterion = convert_to_amp(model=model, - optimizer=optimizer, - criterion=criterion, - mode=amp_mode, - amp_config=cfg_) + cfg_["clip_grad_norm"] = clip_grad_norm + model, optimizer, criterion = convert_to_amp( + model=model, optimizer=optimizer, criterion=criterion, mode=amp_mode, amp_config=cfg_ + ) # get torch ddp config - torch_ddp_cfg = gpc.config.get('torch_ddp', dict()) + torch_ddp_cfg = gpc.config.get("torch_ddp", dict()) # gradient handler - gradient_handler_cfg = gpc.config.get('gradient_handler', None) + gradient_handler_cfg = gpc.config.get("gradient_handler", None) if gradient_handler_cfg is None: # if gradient handler is not specified in the configuration file, # check in the following order @@ -355,54 +370,63 @@ def initialize(model: nn.Module, # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp # 3. if using pipeline and dp size larger than 1, use data parallel grad handler if isinstance(optimizer, ShardedOptimizerV2): - gradient_handler_cfg = [dict(type='ZeROGradientHandler')] + gradient_handler_cfg = [dict(type="ZeROGradientHandler")] if verbose: logger.info( "Training with zero is detected, ZeROGradientHandler is automatically " "added even though not specified in the configuration", - ranks=[0]) + ranks=[0], + ) elif is_using_ddp() and MOE_CONTEXT.is_initialized: - gradient_handler_cfg = [dict(type='MoeGradientHandler')] + gradient_handler_cfg = [dict(type="MoeGradientHandler")] if verbose: logger.info( "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically " "added even though not specified in the configuration", - ranks=[0]) + ranks=[0], + ) elif is_using_sequence(): - model = DDP(model, - process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), - device_ids=[torch.cuda.current_device()], - **torch_ddp_cfg) + model = DDP( + model, + process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg, + ) if verbose: - logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', - ranks=[0]) + logger.info( + "Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism", ranks=[0] + ) elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE: - model = DDP(model, - process_group=gpc.get_group(ParallelMode.DATA), - device_ids=[torch.cuda.current_device()], - **torch_ddp_cfg) + model = DDP( + model, + process_group=gpc.get_group(ParallelMode.DATA), + device_ids=[torch.cuda.current_device()], + **torch_ddp_cfg, + ) if verbose: - logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0]) + logger.info("Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism", ranks=[0]) elif is_using_ddp(): - gradient_handler_cfg = [dict(type='DataParallelGradientHandler')] + gradient_handler_cfg = [dict(type="DataParallelGradientHandler")] if verbose: logger.info( "Data parallel training is detected when using pipeline parallel, " "DataParallelGradientHandler is automatically " "added even though not specified in the configuration", - ranks=[0]) + ranks=[0], + ) # add pipeline parallel gradient handler, if pipeline shared module is detected for param in model.parameters(): - if getattr(param, 'pipeline_shared_module_pg', None) is not None: + if getattr(param, "pipeline_shared_module_pg", None) is not None: if gradient_handler_cfg is None: - gradient_handler_cfg = [dict(type='PipelineSharedModuleGradientHandler')] + gradient_handler_cfg = [dict(type="PipelineSharedModuleGradientHandler")] else: - gradient_handler_cfg.append(dict(type='PipelineSharedModuleGradientHandler')) + gradient_handler_cfg.append(dict(type="PipelineSharedModuleGradientHandler")) if verbose: logger.info( "pipeline_shared_module is detected, PipelineSharedModuleGradientHandler is automatically " "added even though not specified in the configuration", - ranks=[0]) + ranks=[0], + ) break else: if not isinstance(gradient_handler_cfg, list): @@ -418,7 +442,7 @@ def initialize(model: nn.Module, # initialize schedule for engine if is_using_pp(): tensor_shape = get_tensor_shape() - use_interleaved = hasattr(gpc.config, 'model') and hasattr(gpc.config.model, 'num_chunks') + use_interleaved = hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") if gpc.is_initialized(ParallelMode.PARALLEL_1D): scatter_gather = True else: @@ -426,14 +450,16 @@ def initialize(model: nn.Module, if use_interleaved: if isinstance(model, nn.Sequential): model = nn.ModuleList([model]) - schedule = InterleavedPipelineSchedule(gpc.config.NUM_MICRO_BATCHES, - gpc.config.model.num_chunks, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather) + schedule = InterleavedPipelineSchedule( + gpc.config.NUM_MICRO_BATCHES, + gpc.config.model.num_chunks, + tensor_shape=tensor_shape, + scatter_gather_tensors=scatter_gather, + ) else: - schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES, - tensor_shape=tensor_shape, - scatter_gather_tensors=scatter_gather) + schedule = PipelineSchedule( + gpc.config.NUM_MICRO_BATCHES, tensor_shape=tensor_shape, scatter_gather_tensors=scatter_gather + ) else: schedule = NonPipelineSchedule() @@ -443,7 +469,8 @@ def initialize(model: nn.Module, logger.warning( "No PyTorch DDP or gradient handler is set up, please make sure you do not need " "to all-reduce the gradients after a training step.", - ranks=[0]) + ranks=[0], + ) else: gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg] @@ -452,7 +479,7 @@ def initialize(model: nn.Module, optimizer = OptimizerWrapper(optim=optimizer) # gradient accumulation - grad_accum_size = gpc.config.get('gradient_accumulation', None) + grad_accum_size = gpc.config.get("gradient_accumulation", None) if grad_accum_size is not None: optimizer, train_dataloader, gradient_handlers, lr_scheduler = accumulate_gradient( model=model, @@ -460,13 +487,16 @@ def initialize(model: nn.Module, dataloader=train_dataloader, accumulate_size=grad_accum_size, gradient_handlers=gradient_handlers, - lr_scheduler=lr_scheduler) - engine = Engine(model=model, - optimizer=optimizer, - criterion=criterion, - gradient_handlers=gradient_handlers, - clip_grad_norm=clip_grad_norm, - ophook_list=ophooks, - schedule=schedule) + lr_scheduler=lr_scheduler, + ) + engine = Engine( + model=model, + optimizer=optimizer, + criterion=criterion, + gradient_handlers=gradient_handlers, + clip_grad_norm=clip_grad_norm, + ophook_list=ophooks, + schedule=schedule, + ) return engine, train_dataloader, test_dataloader, lr_scheduler diff --git a/colossalai/legacy/nn/_ops/_utils.py b/colossalai/legacy/nn/_ops/_utils.py index a4228fa2116e..b6a99f855a4c 100644 --- a/colossalai/legacy/nn/_ops/_utils.py +++ b/colossalai/legacy/nn/_ops/_utils.py @@ -41,7 +41,7 @@ def _reduce(input_, pg: ProcessGroup): # skip if only one rank involved if pg.tp_world_size() == 1: return input_ - assert input_.device.type == 'cuda' + assert input_.device.type == "cuda" group = pg.tp_process_group() dist.all_reduce(input_, group=group) @@ -56,9 +56,10 @@ def _split(input_, pg: ProcessGroup, dim=-1): # Split along last dimension. dim_size = input_.size(dim) - assert dim_size % world_size == 0, \ - f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) tensor_list = torch.split(input_, dim_size // world_size, dim=dim) rank = pg.tp_local_rank() @@ -77,7 +78,7 @@ def _gather(input_, pg: ProcessGroup, dim=-1): rank = pg.tp_local_rank() tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ - assert input_.device.type == 'cuda' + assert input_.device.type == "cuda" group = pg.tp_process_group() torch.distributed.all_gather(tensor_list, input_, group=group) @@ -203,7 +204,7 @@ def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: return x # TODO: enabling mpi backend to support CPU all_to_all - assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend" + assert x.device.type == "cuda", f"Currently, the collective function dual_all_to_all only supports nccl backend" shapes = list(x.size()) shapes[scatter_dim] = shapes[scatter_dim] // world_size @@ -216,7 +217,6 @@ def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: class _DualAllToAll(torch.autograd.Function): - @staticmethod def forward(ctx, x, pg, scatter_dim, gather_dim): ctx.scatter_dim = scatter_dim @@ -236,16 +236,14 @@ def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int): # table wise embedding shard -def _all_to_all_for_tablewise(x: torch.Tensor, - pg: ProcessGroup, - scatter_strides: List[int], - gather_strides: List[int], - forward=True) -> torch.Tensor: +def _all_to_all_for_tablewise( + x: torch.Tensor, pg: ProcessGroup, scatter_strides: List[int], gather_strides: List[int], forward=True +) -> torch.Tensor: world_size = pg.tp_world_size() rank = pg.tp_local_rank() if world_size == 1: return x - assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend" + assert x.device.type == "cuda", f"Currently, the collective function dual_all_to_all only supports nccl backend" if forward: scatter_list = list(x.split(scatter_strides, 0)) gather_list = [ @@ -266,7 +264,6 @@ def _all_to_all_for_tablewise(x: torch.Tensor, class _DualAllToAllForTablewise(torch.autograd.Function): - @staticmethod def forward(ctx, x, pg, scatter_strides, gather_strides): ctx.pg = pg @@ -276,8 +273,12 @@ def forward(ctx, x, pg, scatter_strides, gather_strides): @staticmethod def backward(ctx, grad): - return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, - forward=False), None, None, None + return ( + _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, forward=False), + None, + None, + None, + ) def dual_all_to_all_tablewise(x, pg, scatter_strides, gather_strides): diff --git a/colossalai/legacy/nn/layer/base_layer.py b/colossalai/legacy/nn/layer/base_layer.py index 01fd9b3e8943..66abc6fb1fd1 100644 --- a/colossalai/legacy/nn/layer/base_layer.py +++ b/colossalai/legacy/nn/layer/base_layer.py @@ -10,44 +10,54 @@ class ParallelLayer(nn.Module): - global_state_dict: bool = True def __init__(self): super().__init__() - self.data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank( - ParallelMode.DATA) - self.data_parallel_size = 1 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size( - ParallelMode.DATA) + self.data_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) + ) + self.data_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_world_size(ParallelMode.DATA) + ) - self.tensor_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank( - ParallelMode.TENSOR) - self.tensor_parallel_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size( - ParallelMode.TENSOR) + self.tensor_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_local_rank(ParallelMode.TENSOR) + ) + self.tensor_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) + ) - self.pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + self.pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + self.pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) - def _load_from_global_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) + def _load_from_global_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def _save_to_global_state_dict(self, destination, prefix, keep_vars): return super()._save_to_state_dict(destination, prefix, keep_vars) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): if self.global_state_dict: if gpc.get_local_rank(ParallelMode.TENSOR) != 0: missing_keys.clear() unexpected_keys.clear() - return self._load_from_global_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, - unexpected_keys, error_msgs) - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs) + return self._load_from_global_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def _save_to_state_dict(self, destination, prefix, keep_vars): if self.global_state_dict: diff --git a/colossalai/legacy/nn/layer/colossalai_layer/__init__.py b/colossalai/legacy/nn/layer/colossalai_layer/__init__.py index ed743820ddbc..7c5449ff5578 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/__init__.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/__init__.py @@ -4,4 +4,4 @@ from .linear import Classifier, Linear from .normalization import LayerNorm -__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch'] +__all__ = ["Linear", "Classifier", "Embedding", "PatchEmbedding", "LayerNorm", "Dropout", "partition_batch"] diff --git a/colossalai/legacy/nn/layer/colossalai_layer/_utils.py b/colossalai/legacy/nn/layer/colossalai_layer/_utils.py index 677cb0e7ac42..98255142a846 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/_utils.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/_utils.py @@ -6,7 +6,7 @@ from ..parallel_3d._operation import split_batch_3d from ..utils import get_tensor_parallel_mode -_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d} +_parallel_split_batch = {"2d": split_batch_2d, "2.5d": split_batch_2p5d, "3d": split_batch_3d} def partition_batch(input_) -> Tensor: @@ -21,7 +21,6 @@ def partition_batch(input_) -> Tensor: class ColossalaiModule(nn.Module): - def __init__(self, module: nn.Module, **kwargs): super().__init__() self.module = module @@ -29,7 +28,7 @@ def __init__(self, module: nn.Module, **kwargs): setattr(self, k, v) def __getattr__(self, name: str): - if name == 'module': + if name == "module": return super().__getattr__(name) elif hasattr(self.module, name): return getattr(self.module, name) diff --git a/colossalai/legacy/nn/layer/colossalai_layer/dropout.py b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py index 7b0481a3f53c..ad6fcc2d8bf4 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/dropout.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/dropout.py @@ -24,7 +24,7 @@ def __init__(self, p: float = 0.5, inplace: bool = False) -> None: super().__init__(drop, tensor_parallel=tensor_parallel) def forward(self, *args): - if self.tensor_parallel in [None, '1d']: + if self.tensor_parallel in [None, "1d"]: return super().forward(*args) else: with seed(ParallelMode.TENSOR): diff --git a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py index 28bcb7ffefb0..e1db0fe98a02 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/embedding.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/embedding.py @@ -15,25 +15,25 @@ from ._utils import ColossalaiModule _parallel_embedding = { - '1d': Embedding1D, - '2d': Embedding2D, - '2.5d': Embedding2p5D, - '3d': Embedding3D, + "1d": Embedding1D, + "2d": Embedding2D, + "2.5d": Embedding2p5D, + "3d": Embedding3D, } _vocab_parallel_embedding = { - '1d': VocabParallelEmbedding1D, - '2d': VocabParallelEmbedding2D, - '2.5d': VocabParallelEmbedding2p5D, - '3d': VocabParallelEmbedding3D + "1d": VocabParallelEmbedding1D, + "2d": VocabParallelEmbedding2D, + "2.5d": VocabParallelEmbedding2p5D, + "3d": VocabParallelEmbedding3D, } _parallel_patchembedding = { None: VanillaPatchEmbedding, - '1d': PatchEmbedding1D, - '2d': PatchEmbedding2D, - '2.5d': PatchEmbedding2p5D, - '3d': PatchEmbedding3D + "1d": PatchEmbedding1D, + "2d": PatchEmbedding2D, + "2.5d": PatchEmbedding2p5D, + "3d": PatchEmbedding3D, } @@ -67,19 +67,24 @@ class Embedding(ColossalaiModule): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: dtype = None, - weight_initializer: Callable = init.normal_(), - vocab_parallel_limit: int = 2048, - *args, - **kwargs) -> None: + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: dtype = None, + weight_initializer: Callable = init.normal_(), + vocab_parallel_limit: int = 2048, + *args, + **kwargs, + ) -> None: tensor_parallel = get_tensor_parallel_mode() if tensor_parallel is None: - embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, - **kwargs).to(dtype).to(get_current_device()) + embed = ( + nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs) + .to(dtype) + .to(get_current_device()) + ) weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) elif num_embeddings <= vocab_parallel_limit: embed = _parallel_embedding[tensor_parallel]( @@ -135,7 +140,7 @@ def __init__( flatten: bool = True, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_() + position_embed_initializer: Callable = init.zeros_(), ) -> None: tensor_parallel = get_tensor_parallel_mode() embed = _parallel_patchembedding[tensor_parallel]( diff --git a/colossalai/legacy/nn/layer/colossalai_layer/linear.py b/colossalai/legacy/nn/layer/colossalai_layer/linear.py index c05ceb66ce25..aa4863e28b81 100644 --- a/colossalai/legacy/nn/layer/colossalai_layer/linear.py +++ b/colossalai/legacy/nn/layer/colossalai_layer/linear.py @@ -5,7 +5,6 @@ from torch import dtype, nn from colossalai.nn import init -from colossalai.utils import get_current_device from ..parallel_1d import * from ..parallel_2d import * @@ -15,21 +14,21 @@ from ..vanilla import * from ._utils import ColossalaiModule -_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D} +_parallel_linear = {None: VanillaLinear, "1d": Linear1D, "2d": Linear2D, "2.5d": Linear2p5D, "3d": Linear3D} _parallel_classifier = { None: VanillaClassifier, - '1d': Classifier1D, - '2d': Classifier2D, - '2.5d': Classifier2p5D, - '3d': Classifier3D + "1d": Classifier1D, + "2d": Classifier2D, + "2.5d": Classifier2p5D, + "3d": Classifier3D, } _vocab_parallel_classifier = { - '1d': VocabParallelClassifier1D, - '2d': VocabParallelClassifier2D, - '2.5d': VocabParallelClassifier2p5D, - '3d': VocabParallelClassifier3D + "1d": VocabParallelClassifier1D, + "2d": VocabParallelClassifier2D, + "2.5d": VocabParallelClassifier2p5D, + "3d": VocabParallelClassifier3D, } @@ -65,19 +64,21 @@ class Linear(ColossalaiModule): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - **kwargs) -> None: + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs, + ) -> None: tensor_parallel = get_tensor_parallel_mode() linear_cls = _parallel_linear[tensor_parallel] - gather_output = kwargs.pop('gather_output', None) - if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available - kwargs['gather_output'] = gather_output + gather_output = kwargs.pop("gather_output", None) + if "gather_output" in inspect.signature(linear_cls.__init__).parameters.keys(): # gather_out arg is available + kwargs["gather_output"] = gather_output layer = linear_cls( in_features, out_features, @@ -108,15 +109,17 @@ class Classifier(ColossalaiModule): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: nn.Parameter = None, - bias: bool = True, - dtype: dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - vocab_parallel_limit: int = 2048) -> None: + def __init__( + self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + vocab_parallel_limit: int = 2048, + ) -> None: tensor_parallel = get_tensor_parallel_mode() if num_classes <= vocab_parallel_limit or tensor_parallel is None: layer = _parallel_classifier[tensor_parallel]( diff --git a/colossalai/legacy/nn/layer/parallel_1d/__init__.py b/colossalai/legacy/nn/layer/parallel_1d/__init__.py index 9cffd4d339f5..35e9ec40d100 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/__init__.py +++ b/colossalai/legacy/nn/layer/parallel_1d/__init__.py @@ -12,6 +12,14 @@ ) __all__ = [ - 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D', - 'VocabParallelEmbedding1D', 'LayerNorm1D', 'PatchEmbedding1D' + "Linear1D", + "Linear1D_Col", + "Linear1D_Row", + "Embedding1D", + "Dropout1D", + "Classifier1D", + "VocabParallelClassifier1D", + "VocabParallelEmbedding1D", + "LayerNorm1D", + "PatchEmbedding1D", ] diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py index db9dfa3667b4..f01da97ba39a 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py @@ -21,7 +21,7 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function): If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps: a value added to the denominator for numerical stability - """ + """ @staticmethod def forward(ctx, input, weight, bias, normalized_shape, eps): @@ -30,8 +30,9 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() - output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, - bias_, ctx.eps) + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output @@ -39,11 +40,9 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): def backward(ctx, grad_output): input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias \ - = fused_mix_prec_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) + grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) return grad_input, grad_weight, grad_bias, None, None diff --git a/colossalai/legacy/nn/layer/parallel_1d/_utils.py b/colossalai/legacy/nn/layer/parallel_1d/_utils.py index 15b41e305cba..93b476e811a4 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_1d/_utils.py @@ -47,9 +47,10 @@ def _split(input_, parallel_mode, dim=-1): # Split along last dimension. dim_size = input_.size(dim) - assert dim_size % world_size == 0, \ - f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) tensor_list = torch.split(input_, dim_size // world_size, dim=dim) rank = gpc.get_local_rank(parallel_mode) diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py index db7986b8e8e5..8304cd2e1eb7 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py @@ -27,7 +27,7 @@ from ..base_layer import ParallelLayer from ..colossalai_layer._utils import ColossalaiModule from ..utils import divide, set_tensor_parallel_attribute_by_partition -from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding +from ..vanilla import VanillaPatchEmbedding from ._operation import linear_with_async_comm from ._utils import ( gather_forward_split_backward, @@ -41,6 +41,7 @@ Fast_LN = None try: from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm except ImportError: pass @@ -67,33 +68,39 @@ class Linear1D(ColossalaiModule): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): parallel_input = get_parallel_input() if not parallel_input and not gather_output: - layer = Linear1D_Col(in_features, - out_features, - bias=bias, - dtype=dtype, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) + layer = Linear1D_Col( + in_features, + out_features, + bias=bias, + dtype=dtype, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) else: - layer = Linear1D_Row(in_features, - out_features, - bias=bias, - dtype=dtype, - parallel_input=parallel_input, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) + layer = Linear1D_Row( + in_features, + out_features, + bias=bias, + dtype=dtype, + parallel_input=parallel_input, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + ) super().__init__(layer) @@ -114,8 +121,30 @@ class LayerNorm1D(ColossalaiModule): """ _fast_ln_supported_sizes = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536 + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, ] def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): @@ -125,6 +154,7 @@ def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): norm = None try: from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) except ImportError: norm = LayerNorm(normalized_shape, eps=eps).to(dtype) @@ -132,8 +162,8 @@ def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): def _load_from_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -171,14 +201,16 @@ class Classifier1D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -189,7 +221,7 @@ def __init__(self, # Parameters. # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -221,8 +253,8 @@ def _set_tensor_parallel_attributes(self): def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -235,50 +267,46 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args): if bias is not None: local_state[bias_key] = bias - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight if self.bias is not None: local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) input_ = input_ else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + assert ( + divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1] + ), "Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size + ) input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) output_parallel = F.linear(input_, self.weight) @@ -307,15 +335,17 @@ class VocabParallelClassifier1D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -327,7 +357,7 @@ def __init__(self, # Parameters. # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False @@ -360,8 +390,8 @@ def _set_tensor_parallel_attributes(self): def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -374,43 +404,37 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args): if bias is not None: local_state[bias_key] = bias - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight if self.bias is not None: local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) # Set up backprop all-reduce. input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) # Matrix multiply. @@ -449,15 +473,17 @@ class Linear1D_Col(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() # Keep input parameters @@ -467,13 +493,13 @@ def __init__(self, self.skip_bias_add = skip_bias_add if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) # Parameters. # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) if bias: @@ -500,8 +526,8 @@ def _set_tensor_parallel_attributes(self): def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -513,41 +539,35 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args): if bias is not None: local_state[bias_key] = bias - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) # Set up backprop all-reduce. # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) input_parallel = input_ @@ -569,7 +589,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: @LAYERS.register_module class Linear1D_Row(ParallelLayer): - r""" Linear layer with row parallelism + r"""Linear layer with row parallelism Args: in_features (int): size of each input sample. @@ -588,16 +608,18 @@ class Linear1D_Row(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1, + ): super().__init__() self.stream_chunk_num = stream_chunk_num @@ -609,14 +631,14 @@ def __init__(self, self.skip_bias_add = skip_bias_add if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # Divide the weight matrix along the last dimension. self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) # Parameters. # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) if self.stream_chunk_num > 1: @@ -647,8 +669,8 @@ def _set_tensor_parallel_attributes(self): def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -660,48 +682,44 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args): if bias is not None: local_state[bias_key] = bias - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) + local_state = partition_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) input_ = input_ else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + assert ( + divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size + ) input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) if self.stream_chunk_num > 1: @@ -712,9 +730,9 @@ def forward(self, input_: Tensor) -> Tensor: handle_list = [] for i in range(self.stream_chunk_num): output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=gpc.get_group(ParallelMode.PARALLEL_1D), - async_op=True) + handle = torch.distributed.all_reduce( + output_parallel_list[i], group=gpc.get_group(ParallelMode.PARALLEL_1D), async_op=True + ) handle_list.append(handle) # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) for handle in handle_list: @@ -763,14 +781,16 @@ class Embedding1D(ParallelLayer): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings @@ -782,7 +802,8 @@ def __init__(self, self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -804,31 +825,31 @@ def _fill_padding_idx_with_zero(self) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}) + local_state = partition_tensor_parallel_state_dict( + local_state, ParallelMode.PARALLEL_1D, dims={weight_key: -1}, partition_states={weight_key: True} + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) @@ -867,14 +888,16 @@ class VocabParallelEmbedding1D(ParallelLayer): `init `_. """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim @@ -889,7 +912,8 @@ def __init__(self, self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)) + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -906,34 +930,38 @@ def reset_parameters(self, weight_initializer) -> None: self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) if weight is not None: local_state[weight_key] = weight - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}) + local_state = partition_tensor_parallel_state_dict( + local_state, ParallelMode.PARALLEL_1D, dims={weight_key: 0}, partition_states={weight_key: True} + ) super()._load_from_global_state_dict(local_state, prefix, *args) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars) + local_state = gather_tensor_parallel_state_dict( + local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars, + ) destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: @@ -943,11 +971,12 @@ def forward(self, input_: Tensor) -> Tensor: masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) # Mask the output embedding. - output_parallel[input_mask, :] = 0. + output_parallel[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) return output @@ -1002,30 +1031,34 @@ class PatchEmbedding1D(ColossalaiModule): :type position_embed_initializer: typing.Callable, optional """ - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - dtype: torch.dtype = None, - flatten: bool = True, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - embed = VanillaPatchEmbedding(img_size, - patch_size, - in_chans, - embed_size, - dtype=dtype, - flatten=flatten, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - position_embed_initializer=position_embed_initializer) + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: torch.dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): + embed = VanillaPatchEmbedding( + img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer, + ) super().__init__(embed) def _load_from_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() - param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed'] + param_keys = [prefix + "weight", prefix + "bias", prefix + "cls_token", prefix + "pos_embed"] if gpc.get_local_rank(ParallelMode.TENSOR) == 0: for key in param_keys: param = state_dict.pop(key, None) diff --git a/colossalai/legacy/nn/layer/parallel_2d/__init__.py b/colossalai/legacy/nn/layer/parallel_2d/__init__.py index 9c65f3608710..8d29c66b3a24 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/__init__.py +++ b/colossalai/legacy/nn/layer/parallel_2d/__init__.py @@ -10,6 +10,13 @@ ) __all__ = [ - 'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', - 'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D' + "split_batch_2d", + "reduce_by_batch_2d", + "Linear2D", + "LayerNorm2D", + "Classifier2D", + "PatchEmbedding2D", + "Embedding2D", + "VocabParallelEmbedding2D", + "VocabParallelClassifier2D", ] diff --git a/colossalai/legacy/nn/layer/parallel_2d/_operation.py b/colossalai/legacy/nn/layer/parallel_2d/_operation.py index 43e14d4a47a5..f1eff7128e7a 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2d/_operation.py @@ -5,10 +5,9 @@ from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd -from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce, reduce_scatter +from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.legacy.global_variables import tensor_parallel_env as env from colossalai.utils import get_current_device @@ -49,17 +48,30 @@ def matmul_2d( col_rank = gpc.get_local_rank(row_parallel_mode) data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = summa_dim**2 - return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, col_parallel_mode, - data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) + return Matmul_AB_2D( + a, + b, + summa_dim, + out_shape, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) class _Classifier2D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -132,10 +144,21 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None -def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, out_shape: Tuple[int, ...], - row_rank: int, col_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: +def classifier_2d( + A: Tensor, + B: Tensor, + bias: Optional[Tensor], + summa_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: r"""2D parallel classifier. Args: @@ -157,9 +180,21 @@ def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Classifier2D.apply(A, B, bias, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, - col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, - tensor_parallel_size) + return _Classifier2D.apply( + A, + B, + bias, + summa_dim, + out_shape, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) class Matmul_AB_2D(torch.autograd.Function): @@ -205,8 +240,7 @@ def forward( # B: [h / q, s / q] # C: [b / q, s, s / q] -> [(b * s) / q, s / q] - assert A.shape[-1] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) + assert A.shape[-1] == B.shape[-2], "Invalid shapes: A={}, B={} for AB.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -226,10 +260,16 @@ def forward( row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_a = ( + summa_dim * row_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_b = ( + col_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opa = [None] * 2 opb = [None] * 2 @@ -278,14 +318,34 @@ def forward( def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_ABT_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) + A_grad = Matmul_ABT_2D.apply( + output_grad, + B, + ctx.summa_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2D.apply( + A, + output_grad, + ctx.summa_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None @@ -329,9 +389,7 @@ def forward( pipeline_parallel_size: int, tensor_parallel_size: int, ) -> Tensor: - - assert A.shape[-1] == B.shape[-1], \ - 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) + assert A.shape[-1] == B.shape[-1], "Invalid shapes: A={}, B={} for ABT.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -351,10 +409,16 @@ def forward( row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_b = ( + col_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + summa_dim * row_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opb = [None] * 2 opr = [None] * 2 @@ -412,14 +476,34 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_AB_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2D.apply(output_grad, A, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) + A_grad = Matmul_AB_2D.apply( + output_grad, + B, + ctx.summa_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2D.apply( + output_grad, + A, + ctx.summa_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None @@ -462,9 +546,7 @@ def forward( pipeline_parallel_size: int, tensor_parallel_size: int, ) -> Tensor: - - assert A.shape[-2] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) + assert A.shape[-2] == B.shape[-2], "Invalid shapes: A={}, B={} for ATB.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -484,10 +566,16 @@ def forward( row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_a = ( + summa_dim * row_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + col_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opa = [None] * 2 opr = [None] * 2 @@ -545,19 +633,38 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_ABT_2D.apply(B, output_grad, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) - B_grad = Matmul_AB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank, - ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size, - ctx.tensor_parallel_size) + A_grad = Matmul_ABT_2D.apply( + B, + output_grad, + ctx.summa_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_AB_2D.apply( + A, + output_grad, + ctx.summa_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None class _Add_Bias_2D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -608,10 +715,20 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: return output_grad, grad, None, None, None, None, None, None, None, None, None, None -def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, row_rank: int, col_rank: int, - row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, skip_bias_add: bool, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: +def add_bias_2d( + input_: Tensor, + bias: Tensor, + output_size_per_partition: int, + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: r"""Matrix add bias: :math:`C = A + b`. Args: @@ -633,17 +750,34 @@ def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, ro The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Add_Bias_2D.apply(input_, bias, output_size_per_partition, row_rank, col_rank, row_parallel_mode, - col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, - pipeline_parallel_size, tensor_parallel_size) + return _Add_Bias_2D.apply( + input_, + bias, + output_size_per_partition, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + skip_bias_add, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) class _Layernorm_2D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode) -> Tensor: + def forward( + ctx: Any, + input_: Tensor, + E_x: Tensor, + Var_x: Tensor, + hidden_size: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + ) -> Tensor: input_ = input_ - E_x # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) ctx.normalized_shape = hidden_size @@ -657,7 +791,7 @@ def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: i @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: row_parallel_mode = ctx.row_parallel_mode - col_parallel_mode = ctx.col_parallel_mode + ctx.col_parallel_mode x, Var_x = ctx.saved_tensors # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True) @@ -676,8 +810,14 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: return input_grad, None, None, None, None, None -def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, - col_parallel_mode: ParallelMode) -> Tensor: +def layernorm_2d( + input_: Tensor, + E_x: Tensor, + Var_x: Tensor, + hidden_size: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, +) -> Tensor: r"""Layernorm. Args: @@ -696,7 +836,6 @@ def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, r class _AllGatherTensor2D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor: @@ -744,15 +883,14 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor: if world_size <= 1: return input_ - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' + assert dim_size % world_size == 0, f"The batch size ({dim_size}) is not a multiple of 2D size ({world_size})." - return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), - dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous() + return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL), dim=dim)[ + gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) + ].contiguous() class _ReduceTensor2D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, parallel_mode): return all_reduce(input_, parallel_mode) @@ -777,7 +915,6 @@ def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: class _ReduceScatterTensor2D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, dim, parallel_mode): ctx.dim = dim @@ -803,14 +940,12 @@ def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMo """ dim_size = tensor.size(dim) world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2D size ({world_size}).' + assert dim_size % world_size == 0, f"The batch size ({dim_size}) is not a multiple of 2D size ({world_size})." return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode) class _ReduceByBatch2D(torch.autograd.Function): - @staticmethod def symbolic(graph, input_, reduce_mean: bool = False): output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) diff --git a/colossalai/legacy/nn/layer/parallel_2d/_utils.py b/colossalai/legacy/nn/layer/parallel_2d/_utils.py index 87ba1bf69691..fe18af26f88f 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_2d/_utils.py @@ -6,15 +6,17 @@ def get_summa_dim_from_env() -> int: try: summa_dim = env.summa_dim - assert summa_dim > 0, 'SUMMA_DIM must be larger than zero' + assert summa_dim > 0, "SUMMA_DIM must be larger than zero" return summa_dim - except KeyError as e: - raise EnvironmentError('SUMMA_DIM is not found in the current environment, ' - 'please make sure that you have used the correct process group initializer') + except KeyError: + raise EnvironmentError( + "SUMMA_DIM is not found in the current environment, " + "please make sure that you have used the correct process group initializer" + ) def assert_summa_initialization(): - assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW), \ - 'Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer' + assert gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and gpc.is_initialized( + ParallelMode.PARALLEL_2D_ROW + ), "Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer" diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py index 893bc74b57d9..3b2e032e5127 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py @@ -55,14 +55,16 @@ class Linear2D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features @@ -80,15 +82,16 @@ def __init__(self, self.hidden_size_per_partition = divide(self.out_features, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter( - torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) + torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) + ) # create bias, shape: [h/q] if bias: self.bias = Parameter(torch.empty(divide(self.out_features, self.summa_dim**2), **factory_kwargs)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) # initialize parameters with seed(ParallelMode.TENSOR): @@ -108,8 +111,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -126,34 +129,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -162,14 +153,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -177,14 +162,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -196,22 +175,53 @@ def forward(self, x: Tensor) -> Tensor: # output: [m/q, n/q, h/q] out_shape = x.shape[:-1] + (self.hidden_size_per_partition,) - output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, - self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + output = Matmul_AB_2D.apply( + x, + self.weight, + self.summa_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) if self.bias is not None: if self.skip_bias_add: - bias = add_bias_2d(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + bias = add_bias_2d( + None, + self.bias, + self.hidden_size_per_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output, bias else: - output = add_bias_2d(output, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = add_bias_2d( + output, + self.bias, + self.hidden_size_per_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output else: return output @@ -249,7 +259,7 @@ def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=N self.partitioned_partition = divide(normalized_shape, self.summa_dim**2) # create parameters - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -266,8 +276,8 @@ def _set_tensor_parallel_attributes(self): def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -283,34 +293,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -319,14 +317,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -334,14 +326,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -349,29 +335,51 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): def forward(self, x: Tensor) -> Tensor: with torch.no_grad(): - E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) E_x /= self.normalized_shape # Var_x in the block below is the sum of input^2 - Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) Var_x /= self.normalized_shape - Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] # this time 1/sqrt(Var_x + epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) - output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, - ParallelMode.PARALLEL_2D_COL) - scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank, - self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + output = layernorm_2d( + x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL + ) + scale = add_bias_2d( + None, + self.weight, + self.partitioned_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) if self.bias is not None: - bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + bias = add_bias_2d( + None, + self.bias, + self.partitioned_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) output = torch.addcmul(bias, scale, output) else: output = torch.mul(scale, output) @@ -400,16 +408,18 @@ class PatchEmbedding2D(ParallelLayer): `init `_. """ - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -426,17 +436,22 @@ def __init__(self, with seed(ParallelMode.TENSOR): self.weight = Parameter( - torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (self.embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype, + ) + ) self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) + torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) self.pos_embed = Parameter( - torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), - device=get_current_device(), - dtype=dtype)) + torch.zeros( + (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + ) + ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self._set_tensor_parallel_attribute() @@ -457,10 +472,10 @@ def reset_parameters(self, weight_initializer, bias_initializer, position_embed_ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -484,67 +499,34 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - local_state = OrderedDict({ - weight_key: self.weight, - bias_key: self.bias, - cls_token_key: self.cls_token, - pos_embed_key: self.pos_embed - }) + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + local_state = OrderedDict( + {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed} + ) # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -552,18 +534,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -573,15 +545,16 @@ def forward(self, input_: Tensor) -> Tensor: input_ = split_batch_2d(input_) B, C, H, W = input_.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL) bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL) output = F.conv2d(input_, weight, bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL) pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL) @@ -623,14 +596,16 @@ class Embedding2D(ParallelLayer): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() assert_summa_initialization() @@ -644,7 +619,8 @@ def __init__(self, self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -665,7 +641,7 @@ def _fill_padding_idx_with_zero(self) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -691,7 +667,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in column groups @@ -754,14 +730,16 @@ class VocabParallelEmbedding2D(ParallelLayer): `init `_. """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim @@ -778,9 +756,12 @@ def __init__(self, self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype, + ) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -796,14 +777,17 @@ def reset_parameters(self, weight_initializer) -> None: self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -829,7 +813,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in column groups @@ -857,10 +841,11 @@ def forward(self, input_: Tensor) -> Tensor: masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) - output_parallel[input_mask, :] = 0. + output_parallel[input_mask, :] = 0.0 output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL) return output @@ -884,14 +869,16 @@ class Classifier2D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -908,7 +895,8 @@ def __init__(self, self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) + torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + ) self.has_weight = True if bias: self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) @@ -938,8 +926,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -957,34 +945,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight @@ -995,14 +971,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) # gather in row groups @@ -1010,14 +980,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -1026,9 +990,21 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): def forward(self, input_: Tensor) -> Tensor: out_shape = input_.shape[:-1] + (self.num_classes,) - return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank, - self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size) + return classifier_2d( + input_, + self.weight, + self.bias, + self.summa_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) @LAYERS.register_module @@ -1050,14 +1026,16 @@ class VocabParallelClassifier2D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features @@ -1074,13 +1052,14 @@ def __init__(self, self.output_size_per_partition = divide(num_classes, self.summa_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False else: self.weight = Parameter( - torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs)) + torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs) + ) self.has_weight = True # create bias, shape: [h/q] if bias: @@ -1109,8 +1088,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -1128,34 +1107,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight @@ -1166,14 +1133,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -1181,14 +1142,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -1200,14 +1155,34 @@ def forward(self, x: Tensor) -> Tensor: # output: [m/q, n/q, h/q] out_shape = x.shape[:-1] + (self.output_size_per_partition,) - output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = Matmul_ABT_2D.apply( + x, + self.weight, + self.summa_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) if self.bias is not None: - output = add_bias_2d(output, self.bias, self.output_size_per_partition, self.row_rank, self.col_rank, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = add_bias_2d( + output, + self.bias, + self.output_size_per_partition, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py b/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py index 23e47e6ed06b..46b4d3f3b782 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/__init__.py @@ -10,6 +10,13 @@ ) __all__ = [ - 'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', - 'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D' + "split_batch_2p5d", + "reduce_by_batch_2p5d", + "Linear2p5D", + "LayerNorm2p5D", + "Classifier2p5D", + "PatchEmbedding2p5D", + "Embedding2p5D", + "VocabParallelClassifier2p5D", + "VocabParallelEmbedding2p5D", ] diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py index 1226162ae399..50900c135cab 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_operation.py @@ -24,7 +24,6 @@ def get_parallel_rank(parallel_mode: ParallelMode): class _Classifier2p5D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -98,10 +97,21 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None -def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: Tuple[int, - ...], row_rank: int, col_rank: int, - row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, data_parallel_rank: int, - pipeline_parallel_rank: int, pipeline_parallel_size: int, tensor_parallel_size: int) -> Tensor: +def classifier_2p5d( + A: Tensor, + B: Tensor, + bias, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: r"""Classifier. Args: @@ -123,9 +133,21 @@ def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: T The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Classifier2p5D.apply(A, B, bias, tesseract_dim, out_shape, row_rank, col_rank, row_parallel_mode, - col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, - tensor_parallel_size) + return _Classifier2p5D.apply( + A, + B, + bias, + tesseract_dim, + out_shape, + row_rank, + col_rank, + row_parallel_mode, + col_parallel_mode, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) class Matmul_AB_2p5D(torch.autograd.Function): @@ -153,16 +175,27 @@ class Matmul_AB_2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, - col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + dep_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: # A: [b / dq, s, h / q] -> [(b * s) / dq, h / q] # B: [h / dq, s / q] # C: [b / dq, s, s / q] -> [(b * s) / dq, s / q] - assert A.shape[-1] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape) + assert A.shape[-1] == B.shape[-2], "Invalid shapes: A={}, B={} for AB.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -182,14 +215,18 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_a = \ - tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_b = \ - col_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_a = ( + tesseract_dim * row_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_b = ( + col_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opa = [None] * 2 opb = [None] * 2 @@ -205,10 +242,9 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple A_list[1 - cur].copy_(A) opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True) B_list[1 - cur].copy_(B) - opb[1 - cur] = dist.broadcast(B_list[1 - cur], - src=src_b + tesseract_dim, - group=col_group, - async_op=True) + opb[1 - cur] = dist.broadcast( + B_list[1 - cur], src=src_b + tesseract_dim, group=col_group, async_op=True + ) if opa[cur] is not None: opa[cur].wait() @@ -242,14 +278,36 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_ABT_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + A_grad = Matmul_ABT_2p5D.apply( + output_grad, + B, + ctx.tesseract_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2p5D.apply( + A, + output_grad, + ctx.tesseract_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None @@ -278,13 +336,23 @@ class Matmul_ABT_2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, - col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: - - assert A.shape[-1] == B.shape[-1], \ - 'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + dep_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: + assert A.shape[-1] == B.shape[-1], "Invalid shapes: A={}, B={} for ABT.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -304,14 +372,18 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_b = \ - col_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_c = \ - tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_b = ( + col_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + tesseract_dim * row_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opb = [None] * 2 opr = [None] * 2 @@ -323,10 +395,9 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple for i in range(tesseract_dim): if i != tesseract_dim - 1: B_list[1 - cur].copy_(B) - opb[1 - cur] = dist.broadcast(B_list[1 - cur], - src=src_b + tesseract_dim, - group=col_group, - async_op=True) + opb[1 - cur] = dist.broadcast( + B_list[1 - cur], src=src_b + tesseract_dim, group=col_group, async_op=True + ) if opr[cur] is not None: opr[cur].wait() @@ -372,14 +443,36 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_AB_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - B_grad = Matmul_ATB_2p5D.apply(output_grad, A, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + A_grad = Matmul_AB_2p5D.apply( + output_grad, + B, + ctx.tesseract_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_ATB_2p5D.apply( + output_grad, + A, + ctx.tesseract_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None @@ -408,13 +501,23 @@ class Matmul_ATB_2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int, - col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int): - - assert A.shape[-2] == B.shape[-2], \ - 'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + dep_rank: int, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ): + assert A.shape[-2] == B.shape[-2], "Invalid shapes: A={}, B={} for ATB.".format(A.shape, B.shape) if ctx: ctx.save_for_backward(A, B) @@ -434,14 +537,18 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple row_group = gpc.get_group(row_parallel_mode) col_group = gpc.get_group(col_parallel_mode) - src_a = \ - tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size - src_c = \ - col_rank + tesseract_dim ** 2 * dep_rank + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_a = ( + tesseract_dim * row_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) + src_c = ( + col_rank + + tesseract_dim**2 * dep_rank + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) opa = [None] * 2 opr = [None] * 2 @@ -499,33 +606,68 @@ def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: A, B = ctx.saved_tensors with torch.no_grad(): - A_grad = Matmul_ABT_2p5D.apply(B, output_grad, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) - B_grad = Matmul_AB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank, - ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode, - ctx.data_parallel_rank, ctx.pipeline_parallel_rank, - ctx.pipeline_parallel_size, ctx.tensor_parallel_size) + A_grad = Matmul_ABT_2p5D.apply( + B, + output_grad, + ctx.tesseract_dim, + ctx.A_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) + B_grad = Matmul_AB_2p5D.apply( + A, + output_grad, + ctx.tesseract_dim, + ctx.B_shape, + ctx.row_rank, + ctx.col_rank, + ctx.dep_rank, + ctx.row_parallel_mode, + ctx.col_parallel_mode, + ctx.data_parallel_rank, + ctx.pipeline_parallel_rank, + ctx.pipeline_parallel_size, + ctx.tensor_parallel_size, + ) return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None class _Add_Bias_2p5D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, - row_rank: int, col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: + def forward( + ctx: Any, + input: Tensor, + bias: Tensor, + output_size_per_partition: int, + tesseract_dim: int, + row_rank: int, + col_rank: int, + dep_rank: int, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + ) -> Tensor: if row_rank == 0: bias_temp = bias.clone() else: bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device()) - src_rank = \ - col_rank + dep_rank * tesseract_dim ** 2 + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + src_rank = ( + col_rank + + dep_rank * tesseract_dim**2 + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode)) ctx.row_rank = row_rank @@ -559,43 +701,120 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: tensor_parallel_size = ctx.tensor_parallel_size if ctx.bias: - dst_rank = \ - col_rank + dep_rank * (tesseract_dim ** 2) + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + dst_rank = ( + col_rank + + dep_rank * (tesseract_dim**2) + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) if row_rank == 0: - return \ - None, output_grad, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None + return ( + None, + output_grad, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) else: grad_tmp = torch.zeros_like(output_grad) - return \ - None, grad_tmp, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None + return ( + None, + grad_tmp, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) else: reduce_dim = tuple(range(output_grad.ndim - 1)) reduce = torch.sum(output_grad, dim=reduce_dim) - dst_rank = \ - col_rank + dep_rank * (tesseract_dim ** 2) + \ - data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \ - pipeline_parallel_rank * tensor_parallel_size + dst_rank = ( + col_rank + + dep_rank * (tesseract_dim**2) + + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + + pipeline_parallel_rank * tensor_parallel_size + ) dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode)) if row_rank == 0: - return \ - output_grad, reduce, None, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None + return ( + output_grad, + reduce, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) else: reduce_tmp = torch.zeros_like(reduce) - return \ - output_grad, reduce_tmp, None, None, None, None, None, None, \ - None, None, None, None, None, None, None, None, None - - -def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int, - col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool, - data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int, - tensor_parallel_size: int) -> Tensor: + return ( + output_grad, + reduce_tmp, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def add_bias_2p5d( + input: Tensor, + bias: Tensor, + output_size_per_partition: int, + tesseract_dim: int, + row_rank: int, + col_rank: int, + dep_rank: int, + col_parallel_mode: ParallelMode, + skip_bias_add: bool, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, +) -> Tensor: r"""Matrix add bias: :math:`C = A + b`. Args: @@ -618,9 +837,21 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found in `parallel_mode `_ """ - return _Add_Bias_2p5D.apply(input, bias, output_size_per_partition, tesseract_dim, row_rank, col_rank, dep_rank, - col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank, - pipeline_parallel_size, tensor_parallel_size) + return _Add_Bias_2p5D.apply( + input, + bias, + output_size_per_partition, + tesseract_dim, + row_rank, + col_rank, + dep_rank, + col_parallel_mode, + skip_bias_add, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) class _Layernorm2p5D(torch.autograd.Function): @@ -640,8 +871,9 @@ class _Layernorm2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, - row_parallel_mode: ParallelMode) -> Tensor: + def forward( + ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode + ) -> Tensor: input = input - E_x # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) ctx.hidden_size = hidden_size @@ -673,8 +905,9 @@ def backward(ctx, output_grad): return input_grad, None, None, None, None, None, None -def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, - row_parallel_mode: ParallelMode) -> Tensor: +def layernorm_2p5d( + input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode +) -> Tensor: r"""Layernorm. Args: @@ -692,7 +925,6 @@ def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, class _AllGatherTensor2p5D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor: @@ -753,9 +985,9 @@ def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: Par def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: grad_shape = (ctx.batch_size,) + output_grad.shape[1:] grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device()) - dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)), - output_grad.contiguous(), - group=gpc.get_group(ctx.para_mode)) + dist.all_gather( + list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode) + ) return grad, None, None @@ -775,15 +1007,16 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor: if world_size <= 1: return input_ - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})." - return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), - dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() + return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), dim=dim)[ + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) + ].contiguous() class _ReduceTensor2p5D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, parallel_mode): return all_reduce(input_, parallel_mode) @@ -808,7 +1041,6 @@ def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: class _ReduceScatterTensor2p5D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, dim, parallel_mode): ctx.dim = dim @@ -834,14 +1066,14 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel """ dim_size = input_.size(dim) world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).' + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})." return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode) class _RreduceByBatch2p5D(torch.autograd.Function): - @staticmethod def symbolic(graph, input_, reduce_mean: bool = False): output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py index 69a350a977ac..8cda15aed2a7 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/_utils.py @@ -7,19 +7,24 @@ def get_tesseract_dim_dep_from_env(): try: tesseract_dim = env.tesseract_dim tesseract_dep = env.tesseract_dep - assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero' - assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero' + assert tesseract_dim > 0, "TESSERACT_DIM must be larger than zero" + assert tesseract_dep > 0, "TESSERACT_DEP must be larger than zero" return tesseract_dim, tesseract_dep - except KeyError as e: - raise EnvironmentError('TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, ' - 'please make sure that you have used the correct process group initializer') + except KeyError: + raise EnvironmentError( + "TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, " + "please make sure that you have used the correct process group initializer" + ) def assert_tesseract_initialization(): - assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \ - gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \ - 'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \ - 'must be initialized by the process group initializer' + assert ( + gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) + and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) + and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) + and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ) + ), ( + "Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ " + "must be initialized by the process group initializer" + ) diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py index b4aa9f16ddf0..fc2e35f36cbc 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py @@ -56,14 +56,16 @@ class Linear2p5D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features @@ -82,15 +84,16 @@ def __init__(self, self.hidden_size_per_partition = divide(out_features, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter( - torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)) + torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs) + ) # create bias, shape: [h/q] if bias: self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) # initialize parameters with seed(ParallelMode.TENSOR): @@ -110,8 +113,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -124,43 +127,33 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[bias_key] = bias # broadcast in dep groups - if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0 and \ - gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: + if ( + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0 + and gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0 + ): broadcast_state_dict(local_state, ParallelMode.PARALLEL_2P5D_DEP) # partition in column groups if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0: local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # partition in row groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0: - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -169,14 +162,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in column groups @@ -184,14 +171,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -221,16 +202,38 @@ def forward(self, x: Tensor) -> Tensor: if self.bias is not None: if self.skip_bias_add: - bias = add_bias_2p5d(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + bias = add_bias_2p5d( + None, + self.bias, + self.hidden_size_per_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output, bias else: - output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, - self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, - False, self.data_parallel_rank, self.pipeline_parallel_rank, - self.pipeline_parallel_size, self.tensor_parallel_size) + output = add_bias_2p5d( + output, + self.bias, + self.hidden_size_per_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output else: return output @@ -266,10 +269,10 @@ def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=N self.tesseract_dim, _ = get_tesseract_dim_dep_from_env() # partitioning dimension - self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * + self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # * # create parameters - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs)) if bias: @@ -286,8 +289,8 @@ def _set_tensor_parallel_attribute(self): def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -303,34 +306,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -339,14 +330,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -354,14 +339,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -369,29 +348,51 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): def forward(self, x: Tensor) -> Tensor: with torch.no_grad(): - E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] + E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) E_x /= self.normalized_shape # Var_x in the block below is the sum of input^2 - Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] + Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1] torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) Var_x /= self.normalized_shape - Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] + Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1] # this time 1/sqrt(Var_x + epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW) - scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + scale = add_bias_2p5d( + None, + self.weight, + self.partitioned_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) if self.bias is not None: - bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + bias = add_bias_2p5d( + None, + self.bias, + self.partitioned_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + True, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) output = torch.addcmul(bias, scale, output) else: output = torch.mul(scale, output) @@ -420,16 +421,18 @@ class PatchEmbedding2p5D(ParallelLayer): `init `_. """ - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -446,17 +449,22 @@ def __init__(self, with seed(ParallelMode.TENSOR): self.weight = Parameter( - torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (self.embed_size_per_partition, in_chans, *self.patch_size), + device=get_current_device(), + dtype=dtype, + ) + ) self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype)) self.cls_token = Parameter( - torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)) + torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) self.pos_embed = Parameter( - torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition), - device=get_current_device(), - dtype=dtype)) + torch.zeros( + (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype + ) + ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self._set_tensor_parallel_attribute() @@ -477,10 +485,10 @@ def reset_parameters(self, weight_initializer, bias_initializer, position_embed_ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -504,67 +512,34 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - local_state = OrderedDict({ - weight_key: self.weight, - bias_key: self.bias, - cls_token_key: self.cls_token, - pos_embed_key: self.pos_embed - }) + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + local_state = OrderedDict( + {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed} + ) # gather in column groups local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, keep_vars=keep_vars, ) # gather in row groups @@ -572,18 +547,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -593,15 +558,16 @@ def forward(self, input_: Tensor) -> Tensor: input_ = split_batch_2p5d(input_, 0) B, C, H, W = input_.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." weight = all_gather_tensor_2p5d(self.weight, 0, ParallelMode.PARALLEL_2P5D_COL) bias = all_gather_tensor_2p5d(self.bias, 0, ParallelMode.PARALLEL_2P5D_COL) output = F.conv2d(input_, weight, bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL) pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL) @@ -643,14 +609,16 @@ class Embedding2p5D(ParallelLayer): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() assert_tesseract_initialization() @@ -664,7 +632,8 @@ def __init__(self, self.embed_kwargs = kwargs self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -685,7 +654,7 @@ def _fill_padding_idx_with_zero(self) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -711,7 +680,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in column groups @@ -775,14 +744,16 @@ class VocabParallelEmbedding2p5D(ParallelLayer): `init `_. """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim @@ -799,9 +770,12 @@ def __init__(self, self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype, + ) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -817,14 +791,13 @@ def reset_parameters(self, weight_initializer) -> None: self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.vocab_start_index <= self.padding_idx < self.vocab_end_index: + if self.padding_idx is not None and self.vocab_start_index <= self.padding_idx < self.vocab_end_index: with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -850,7 +823,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in column groups @@ -880,11 +853,12 @@ def forward(self, input_: Tensor) -> Tensor: masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) # Mask the output embedding. - output_parallel[input_mask, :] = 0. + output_parallel[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. output = reduce_scatter_tensor_2p5d(output_parallel, 0, ParallelMode.PARALLEL_2P5D_COL) return output @@ -909,14 +883,16 @@ class Classifier2p5D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -934,7 +910,8 @@ def __init__(self, self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)) + torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype) + ) self.has_weight = True if bias: self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) @@ -964,8 +941,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -983,34 +960,22 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight @@ -1021,14 +986,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) # gather in row groups @@ -1036,14 +995,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -1052,10 +1005,21 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): def forward(self, input_: Tensor) -> Tensor: out_shape = input_.shape[:-1] + (self.num_classes,) - return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank, - self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + return classifier_2p5d( + input_, + self.weight, + self.bias, + self.tesseract_dim, + out_shape, + self.row_rank, + self.col_rank, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) @LAYERS.register_module @@ -1077,14 +1041,16 @@ class VocabParallelClassifier2p5D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features @@ -1102,13 +1068,14 @@ def __init__(self, self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim) # create weight, shape: [k/q, h/q] - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} if weight is not None: self.weight = weight self.has_weight = False else: self.weight = Parameter( - torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs)) + torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs) + ) self.has_weight = True # create bias, shape: [h/q] if bias: @@ -1137,8 +1104,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -1156,27 +1123,15 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_ROW, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in column groups local_state = partition_tensor_parallel_state_dict( local_state, ParallelMode.PARALLEL_2P5D_COL, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) @@ -1203,8 +1158,19 @@ def forward(self, x: Tensor) -> Tensor: ) if self.bias is not None: - output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank, - self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, False, - self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, - self.tensor_parallel_size) + output = add_bias_2p5d( + output, + self.bias, + self.hidden_size_per_partition, + self.tesseract_dim, + self.row_rank, + self.col_rank, + self.dep_rank, + ParallelMode.PARALLEL_2P5D_COL, + False, + self.data_parallel_rank, + self.pipeline_parallel_rank, + self.pipeline_parallel_size, + self.tensor_parallel_size, + ) return output diff --git a/colossalai/legacy/nn/layer/parallel_3d/__init__.py b/colossalai/legacy/nn/layer/parallel_3d/__init__.py index 17fe8403c585..5d38f6a56874 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/__init__.py +++ b/colossalai/legacy/nn/layer/parallel_3d/__init__.py @@ -10,6 +10,14 @@ ) __all__ = [ - 'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', - 'Classifier3D', 'Embedding3D', 'VocabParallelEmbedding3D', 'VocabParallelClassifier3D' + "reduce_by_batch_3d", + "split_tensor_3d", + "split_batch_3d", + "Linear3D", + "LayerNorm3D", + "PatchEmbedding3D", + "Classifier3D", + "Embedding3D", + "VocabParallelEmbedding3D", + "VocabParallelClassifier3D", ] diff --git a/colossalai/legacy/nn/layer/parallel_3d/_operation.py b/colossalai/legacy/nn/layer/parallel_3d/_operation.py index c6374efb7124..fe42d8e28111 100755 --- a/colossalai/legacy/nn/layer/parallel_3d/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_3d/_operation.py @@ -16,7 +16,6 @@ class _Linear3D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -52,7 +51,8 @@ def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) weight_grad = torch.matmul( - input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]) + ) weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True) weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) @@ -92,7 +92,6 @@ def linear_3d( class _Classifier3D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -131,7 +130,8 @@ def forward( def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_, weight = ctx.saved_tensors weight_grad = torch.matmul( - output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1])) + output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]) + ) weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode) if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode): weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True) @@ -187,7 +187,6 @@ def classifier_3d( class _VocabParallelClassifier3D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward( @@ -230,7 +229,8 @@ def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True) weight_grad = torch.matmul( - input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1])) + input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]) + ) weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True) weight_grad = push_async_grad(op, weight_grad, ctx.weight_id) @@ -296,7 +296,7 @@ def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor): # dbias, dweight = grad, grad * mu / sigma dz = grad * weight dmu = dz / sigma - dvar = dz * mu * (-0.5) * sigma**(-3) + dvar = dz * mu * (-0.5) * sigma ** (-3) dmean = -dmu dvar = torch.sum(dvar, -1, keepdim=True) dmean = torch.sum(dmean, -1, keepdim=True) @@ -305,7 +305,6 @@ def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor): class _Layernorm3D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward( @@ -415,20 +414,24 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te """ dim_size = tensor.size(dim) world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' + assert dim_size % world_size == 0, ( + f"The dimension {dim} to split, size ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) if tensor.size(dim) <= 1: return tensor - output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), - dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous() + output = torch.chunk(tensor, gpc.get_world_size(parallel_mode), dim=dim)[ + gpc.get_local_rank(parallel_mode) + ].contiguous() return output -def split_batch_3d(input_: Tensor, - dim: int = 0, - input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, - weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor: +def split_batch_3d( + input_: Tensor, + dim: int = 0, + input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, + weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT, +) -> Tensor: r"""Splits 3D tensor in batch. Args: @@ -456,7 +459,6 @@ def split_batch_3d(input_: Tensor, class _ReduceTensor3D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, parallel_mode): return all_reduce(input_, parallel_mode) @@ -481,7 +483,6 @@ def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor: class _AllGatherTensor3D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, dim, parallel_mode): ctx.dim = dim @@ -511,7 +512,6 @@ def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) class _ReduceScatterTensor3D(torch.autograd.Function): - @staticmethod def forward(ctx, input_, dim, parallel_mode): ctx.dim = dim @@ -538,21 +538,23 @@ def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMo """ dim_size = tensor.size(dim) world_size = gpc.get_world_size(parallel_mode) - assert dim_size % world_size == 0, \ - f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size}).' + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of square of 3D depth ({world_size})." return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode) class _ReduceByBatch3D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, - input_: Tensor, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - reduce_mean: bool = False) -> Tensor: + def forward( + ctx, + input_: Tensor, + input_parallel_mode: ParallelMode, + weight_parallel_mode: ParallelMode, + reduce_mean: bool = False, + ) -> Tensor: output = all_reduce(input_, input_parallel_mode) output = all_reduce(output, weight_parallel_mode) ctx.reduce_mean = reduce_mean @@ -571,10 +573,9 @@ def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]: return output_grad, None, None, None -def reduce_by_batch_3d(tensor: Tensor, - input_parallel_mode: ParallelMode, - weight_parallel_mode: ParallelMode, - reduce_mean: bool = False) -> Tensor: +def reduce_by_batch_3d( + tensor: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, reduce_mean: bool = False +) -> Tensor: r"""All-reduce the input from the model parallel region. Args: diff --git a/colossalai/legacy/nn/layer/parallel_3d/_utils.py b/colossalai/legacy/nn/layer/parallel_3d/_utils.py index cb300c2a9684..8c967da74e67 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/_utils.py +++ b/colossalai/legacy/nn/layer/parallel_3d/_utils.py @@ -18,17 +18,24 @@ def get_depth_from_env() -> int: try: depth = env.depth_3d - assert depth > 0, 'DEPTH must be greater than zero' + assert depth > 0, "DEPTH must be greater than zero" return depth - except KeyError as e: - raise EnvironmentError('DEPTH is not found in the current environment, ' - 'please make sure that you have used the correct process group initializer') + except KeyError: + raise EnvironmentError( + "DEPTH is not found in the current environment, " + "please make sure that you have used the correct process group initializer" + ) def get_parallel_mode_from_env(group): - assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \ - f'{group} is not valid for 3D tensor parallelism.' + assert group in [ + INPUT_GROUP_3D, + WEIGHT_GROUP_3D, + OUTPUT_GROUP_3D, + INPUT_X_WEIGHT_3D, + OUTPUT_X_WEIGHT_3D, + ], f"{group} is not valid for 3D tensor parallelism." return getattr(env, group) @@ -44,12 +51,10 @@ def dbg_check_shape(tensor: Tensor, shape: tuple): rank = gpc.get_global_rank() if rank == 0: print(tensor.shape) - assert tensor.shape == shape, \ - '{} does not match {}'.format(tensor.shape, shape) + assert tensor.shape == shape, "{} does not match {}".format(tensor.shape, shape) class AsyncGradientBucket(object): - def __init__(self): self.bucket = OrderedDict() diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py index d6aaa427b9e6..196679994197 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py @@ -59,7 +59,6 @@ class LayerNorm3D(ParallelLayer): """ def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None): - super().__init__() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -70,10 +69,12 @@ def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=N self.normalized_shape_per_partition = divide(normalized_shape, self.depth) self.weight = Parameter( - torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) + torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + ) if bias: self.bias = Parameter( - torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)) + torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype) + ) else: self.bias = None self.variance_epsilon = eps @@ -94,8 +95,8 @@ def reset_parameters(self) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -107,15 +108,11 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[bias_key] = bias # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, + dims={weight_key: 0, bias_key: 0}, partition_states={ weight_key: True, bias_key: True, @@ -130,26 +127,19 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -185,14 +175,16 @@ class Linear3D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.out_features = out_features @@ -207,13 +199,17 @@ def __init__(self, self.bias_features_per_partition = divide(out_features, self.depth) self.weight = Parameter( - torch.empty(self.in_features_per_partition, - self.out_features_per_partition, - device=get_current_device(), - dtype=dtype)) + torch.empty( + self.in_features_per_partition, + self.out_features_per_partition, + device=get_current_device(), + dtype=dtype, + ) + ) if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) + torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -239,15 +235,17 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, - gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], - self.output_x_weight_parallel_mode) + broadcast( + self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode, + ) self.bias.register_hook(self._sync_grad_hook) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -260,53 +258,34 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[bias_key] = bias # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # partition in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.input_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in weight groups local_state = partition_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -315,14 +294,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) # gather in input groups @@ -330,30 +303,17 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, self.input_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -396,14 +356,16 @@ class Classifier3D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -418,7 +380,8 @@ def __init__(self, self.has_weight = False else: self.weight = Parameter( - torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)) + torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype) + ) self.has_weight = True if bias: self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) @@ -449,8 +412,8 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -464,19 +427,12 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[bias_key] = bias # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # broadcast in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: @@ -487,8 +443,8 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict() if self.has_weight: local_state[weight_key] = self.weight @@ -496,19 +452,12 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state[bias_key] = self.bias # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -544,14 +493,16 @@ class VocabParallelClassifier3D(ParallelLayer): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -569,14 +520,18 @@ def __init__(self, self.has_weight = False else: self.weight = Parameter( - torch.empty(self.out_features_per_partition, - self.in_features_per_partition, - device=get_current_device(), - dtype=dtype)) + torch.empty( + self.out_features_per_partition, + self.in_features_per_partition, + device=get_current_device(), + dtype=dtype, + ) + ) self.has_weight = True if bias: self.bias = Parameter( - torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)) + torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype) + ) else: self.bias = None @@ -602,15 +557,17 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, - gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], - self.output_x_weight_parallel_mode) + broadcast( + self.bias, + gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0], + self.output_x_weight_parallel_mode, + ) register_async_grad_hook(self.bias) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight if self.has_weight: @@ -624,53 +581,34 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[bias_key] = bias # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) # partition in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.input_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, ) # partition in weight groups local_state = partition_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, ) super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' + weight_key = prefix + "weight" + bias_key = prefix + "bias" local_state = OrderedDict({weight_key: self.weight}) if self.bias is not None: local_state[bias_key] = self.bias @@ -679,14 +617,8 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, self.weight_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) # gather in input groups @@ -694,30 +626,17 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): local_state = gather_tensor_parallel_state_dict( local_state, self.input_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, + dims={weight_key: 0, bias_key: 0}, + partition_states={weight_key: True, bias_key: True}, keep_vars=keep_vars, ) # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, + dims={weight_key: -1, bias_key: 0}, + partition_states={weight_key: True, bias_key: False}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -756,16 +675,18 @@ class PatchEmbedding3D(ParallelLayer): `init `_. """ - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): super().__init__() self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -783,15 +704,18 @@ def __init__(self, self.flatten = flatten self.weight = nn.Parameter( - torch.empty((embed_size_per_partition, in_chans, *self.patch_size), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype + ) + ) self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype)) self.cls_token = nn.Parameter( - torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) + torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)) + torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self._set_tensor_parallel_attributes() @@ -826,10 +750,10 @@ def reset_parameters(self, weight_initializer, bias_initializer, position_embed_ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -849,23 +773,12 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[pos_embed_key] = pos_embed # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, ) # broadcast in input groups if gpc.get_local_rank(self.weight_parallel_mode) == 0: @@ -876,47 +789,33 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - cls_token_key = prefix + 'cls_token' - pos_embed_key = prefix + 'pos_embed' - local_state = OrderedDict({ - weight_key: self.weight, - bias_key: self.bias, - cls_token_key: self.cls_token, - pos_embed_key: self.pos_embed - }) + weight_key = prefix + "weight" + bias_key = prefix + "bias" + cls_token_key = prefix + "cls_token" + pos_embed_key = prefix + "pos_embed" + local_state = OrderedDict( + {weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed} + ) # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, - dims={ - weight_key: 0, - bias_key: 0, - cls_token_key: -1, - pos_embed_key: -1 - }, - partition_states={ - weight_key: True, - bias_key: True, - cls_token_key: True, - pos_embed_key: True - }, + dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1}, + partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True}, keep_vars=keep_vars, ) if gpc.get_local_rank(ParallelMode.TENSOR) == 0: destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_3d(input_, - input_parallel_mode=self.input_parallel_mode, - weight_parallel_mode=self.weight_parallel_mode) + input_ = split_batch_3d( + input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode + ) output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = self.cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1) @@ -956,14 +855,16 @@ class Embedding3D(ParallelLayer): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.depth = get_depth_from_env() self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) @@ -979,7 +880,8 @@ def __init__(self, self.embed_kwargs = kwargs self.weight = nn.Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -996,8 +898,9 @@ def reset_parameters(self, weight_initializer) -> None: fan_in, fan_out = self.num_embeddings, self.embed_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() - broadcast(self.weight, - gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode) + broadcast( + self.weight, gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode + ) self.weight.register_hook(self._sync_grad_hook) def _fill_padding_idx_with_zero(self) -> None: @@ -1007,7 +910,7 @@ def _fill_padding_idx_with_zero(self) -> None: def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -1015,8 +918,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[weight_key] = weight # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, @@ -1032,12 +934,11 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, @@ -1049,9 +950,9 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): destination.update(local_state) def forward(self, input_: Tensor) -> Tensor: - input_ = split_batch_3d(input_, - input_parallel_mode=self.input_parallel_mode, - weight_parallel_mode=self.weight_parallel_mode) + input_ = split_batch_3d( + input_, input_parallel_mode=self.input_parallel_mode, weight_parallel_mode=self.weight_parallel_mode + ) output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) return output @@ -1088,14 +989,16 @@ class VocabParallelEmbedding3D(ParallelLayer): `init `_. """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim @@ -1114,9 +1017,12 @@ def __init__(self, self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition * self.depth self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), - device=get_current_device(), - dtype=dtype)) + torch.empty( + (self.num_embeddings_per_partition, self.embed_dim_per_partition), + device=get_current_device(), + dtype=dtype, + ) + ) self.reset_parameters(weight_initializer) self._set_tensor_parallel_attributes() @@ -1132,14 +1038,17 @@ def reset_parameters(self, weight_initializer) -> None: self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() - weight_key = prefix + 'weight' + weight_key = prefix + "weight" if gpc.get_local_rank(ParallelMode.TENSOR) == 0: # weight weight = state_dict.pop(weight_key, None) @@ -1147,8 +1056,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state[weight_key] = weight # partition in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = partition_tensor_parallel_state_dict( local_state, self.output_parallel_mode, @@ -1174,7 +1082,7 @@ def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' + weight_key = prefix + "weight" local_state = OrderedDict({weight_key: self.weight}) # gather in weight groups @@ -1195,8 +1103,7 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): keep_vars=keep_vars, ) # gather in output groups - if gpc.get_local_rank(self.input_parallel_mode) == 0 and \ - gpc.get_local_rank(self.weight_parallel_mode) == 0: + if gpc.get_local_rank(self.input_parallel_mode) == 0 and gpc.get_local_rank(self.weight_parallel_mode) == 0: local_state = gather_tensor_parallel_state_dict( local_state, self.output_parallel_mode, @@ -1218,7 +1125,7 @@ def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - output_parallel[input_mask, :] = 0. + output_parallel[input_mask, :] = 0.0 output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode) return output diff --git a/colossalai/legacy/nn/layer/parallel_sequence/__init__.py b/colossalai/legacy/nn/layer/parallel_sequence/__init__.py index d92d66d40a8e..d64aba6bafe4 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/__init__.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/__init__.py @@ -1,4 +1,4 @@ from ._operation import RingAV, RingQK from .layers import TransformerSelfAttentionRing -__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK'] +__all__ = ["TransformerSelfAttentionRing", "RingAV", "RingQK"] diff --git a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py index ea1863f0b474..24d5499e3a5f 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/_operation.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/_operation.py @@ -25,11 +25,13 @@ def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length): ctx.sub_seq_length = sub_seq_length # create local segment of attention score - attention_score = torch.empty(batch_size * num_attention_heads, - sub_seq_length, - sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), - dtype=sub_q.dtype, - device=get_current_device()) + attention_score = torch.empty( + batch_size * num_attention_heads, + sub_seq_length, + sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE), + dtype=sub_q.dtype, + device=get_current_device(), + ) # compute local QK^T part_a = torch.matmul(sub_q, sub_k.transpose(2, 1)) @@ -51,7 +53,10 @@ def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads, sub_seq_length): @staticmethod @custom_bwd def backward(ctx, grad_output): - sub_q, sub_k, = ctx.saved_tensors + ( + sub_q, + sub_k, + ) = ctx.saved_tensors local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) @@ -59,7 +64,7 @@ def backward(ctx, grad_output): grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q) dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE)) - grad_k = grad_k[:, local_rank * ctx.sub_seq_length:(local_rank + 1) * ctx.sub_seq_length] + grad_k = grad_k[:, local_rank * ctx.sub_seq_length : (local_rank + 1) * ctx.sub_seq_length] grad_k /= local_world_size # calculate gradient for sub_q @@ -96,11 +101,13 @@ def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads, attent local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length) - sub_attention_result = torch.zeros(batch_size * num_attention_heads, - sub_seq_length, - attention_head_size, - device=get_current_device(), - dtype=attention_score.dtype) + sub_attention_result = torch.zeros( + batch_size * num_attention_heads, + sub_seq_length, + attention_head_size, + device=get_current_device(), + dtype=attention_score.dtype, + ) # save tensors for backward ctx.save_for_backward(attention_score, sub_v) diff --git a/colossalai/legacy/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py index 033c1be962ae..063b0cd8e2b2 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/layers.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/layers.py @@ -8,7 +8,6 @@ import torch.nn.functional as F from torch.nn import Parameter -import colossalai from colossalai.kernel import FusedScaleMaskSoftmax from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType from colossalai.legacy.context import seed @@ -33,18 +32,20 @@ class TransformerSelfAttentionRing(nn.Module): """ - def __init__(self, - hidden_size, - num_attention_heads, - attention_dropout, - attention_mask_func, - layer_number, - apply_query_key_layer_scaling: bool = False, - convert_fp16_to_fp32_in_softmax: bool = False, - attn_mask_type=AttnMaskType.padding, - masked_softmax_fusion=True, - fp16=False, - bf16=False): + def __init__( + self, + hidden_size, + num_attention_heads, + attention_dropout, + attention_mask_func, + layer_number, + apply_query_key_layer_scaling: bool = False, + convert_fp16_to_fp32_in_softmax: bool = False, + attn_mask_type=AttnMaskType.padding, + masked_softmax_fusion=True, + fp16=False, + bf16=False, + ): super().__init__() self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax self.apply_query_key_layer_scaling = apply_query_key_layer_scaling @@ -59,8 +60,9 @@ def __init__(self, if self.apply_query_key_layer_scaling: self.convert_fp16_to_fp32_in_softmax = True - assert self.hidden_size % self.num_attention_heads == 0, \ - 'hidden size is not divisible by the number of attention heads' + assert ( + self.hidden_size % self.num_attention_heads == 0 + ), "hidden size is not divisible by the number of attention heads" self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads @@ -79,9 +81,15 @@ def __init__(self, self.coeff = layer_number self.norm_factor *= self.coeff - self.scale_mask_softmax = FusedScaleMaskSoftmax(fp16, bf16, self.attn_mask_type, masked_softmax_fusion, - self.attention_mask_func, self.convert_fp16_to_fp32_in_softmax, - self.coeff) + self.scale_mask_softmax = FusedScaleMaskSoftmax( + fp16, + bf16, + self.attn_mask_type, + masked_softmax_fusion, + self.attention_mask_func, + self.convert_fp16_to_fp32_in_softmax, + self.coeff, + ) self.attention_dropout = nn.Dropout(attention_dropout) @@ -102,21 +110,28 @@ def forward(self, hidden_states, attention_mask): mixed_x_layer = self.query_key_value(hidden_states) # [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size] - new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads, - 3 * self.hidden_size_per_attention_head) + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads, + 3 * self.hidden_size_per_attention_head, + ) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # split into query, key and value last_dim = mixed_x_layer.dim() - 1 last_dim_value = mixed_x_layer.size(-1) - assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \ - 'cannot be divided into query, key and value' + assert last_dim_value % 3 == 0, ( + "the last dimension is not a multiple of 3, " "cannot be divided into query, key and value" + ) partition_size = last_dim_value // 3 (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim) # attention scores: [batch_size, num_heads, sub_seq_len, seq_len] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), - key_layer.size(0) * self.world_size) + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0) * self.world_size, + ) # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) @@ -125,11 +140,12 @@ def forward(self, hidden_states, attention_mask): # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len] attention_scores = RingQK.apply( - query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size] - key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size], + query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size] + key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size], batch_size, self.num_attention_heads, - sub_seq_length) + sub_seq_length, + ) attention_scores /= self.norm_factor @@ -151,12 +167,18 @@ def forward(self, hidden_states, attention_mask): # # change view [b * num_heads, sub_seq_len, seq_len] attention_probs = attention_probs.view( - attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3)) + attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3) + ) # matmul: [batch_size * num_heads, sub_seq_len, head_size] - context_layer = RingAV.apply(attention_probs, - value_layer.transpose(0, 1).contiguous(), batch_size, self.num_attention_heads, - self.hidden_size_per_attention_head, sub_seq_length) + context_layer = RingAV.apply( + attention_probs, + value_layer.transpose(0, 1).contiguous(), + batch_size, + self.num_attention_heads, + self.hidden_size_per_attention_head, + sub_seq_length, + ) # change view [batch_size, num_heads, sub_seq_len, head_size] context_layer = context_layer.view(*output_size) @@ -165,8 +187,9 @@ def forward(self, hidden_states, attention_mask): context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_attention_head * - self.num_attention_heads,) + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_attention_head * self.num_attention_heads, + ) context_layer = context_layer.view(*new_context_layer_shape) output, bias = self.dense(context_layer) @@ -174,11 +197,13 @@ def forward(self, hidden_states, attention_mask): return output, bias def __repr__(self): - return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \ - f'layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, ' \ - f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \ - f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \ - f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})' + return ( + f"TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, " + f"layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, " + f"attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, " + f"hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, " + f"convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})" + ) class _Linear(nn.Module): @@ -208,10 +233,12 @@ def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): self.output_size = output_size self.skip_bias_add = skip_bias_add - self.weight = Parameter(torch.empty( - self.output_size, - self.input_size, - )) + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size, + ) + ) nn.init.xavier_normal_(self.weight) if bias: @@ -220,7 +247,7 @@ def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): with torch.no_grad(): self.bias.zero_() else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def forward(self, input_): # Matrix multiply. @@ -233,5 +260,7 @@ def forward(self, input_): return output def __repr__(self): - return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ - f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' + return ( + f"Linear(in_features={self.input_size}, out_features={self.output_size}, " + + f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})" + ) diff --git a/colossalai/legacy/nn/layer/utils/__init__.py b/colossalai/legacy/nn/layer/utils/__init__.py index 56e969bfd0bd..4e78b228eb4f 100644 --- a/colossalai/legacy/nn/layer/utils/__init__.py +++ b/colossalai/legacy/nn/layer/utils/__init__.py @@ -10,6 +10,12 @@ ) __all__ = [ - 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size', - 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple' + "CheckpointModule", + "divide", + "ACT2FN", + "set_tensor_parallel_attribute_by_size", + "set_tensor_parallel_attribute_by_partition", + "get_tensor_parallel_mode", + "_ntuple", + "to_2tuple", ] diff --git a/colossalai/legacy/nn/layer/utils/common.py b/colossalai/legacy/nn/layer/utils/common.py index 3148a0bed570..fd6a5b38d60a 100644 --- a/colossalai/legacy/nn/layer/utils/common.py +++ b/colossalai/legacy/nn/layer/utils/common.py @@ -14,7 +14,6 @@ class CheckpointModule(nn.Module): - def __init__(self, checkpoint: bool = True, offload: bool = False): super().__init__() self.checkpoint = checkpoint @@ -22,7 +21,7 @@ def __init__(self, checkpoint: bool = True, offload: bool = False): self._offload = offload def _forward(self, *args, **kwargs): - raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward') + raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward") def forward(self, *args, **kwargs): if self._use_checkpoint: @@ -49,9 +48,8 @@ def divide(numerator, denominator): Returns: int: the result of exact division. """ - assert denominator != 0, 'denominator can not be zero' - assert numerator % denominator == 0, \ - '{} is not divisible by {}'.format(numerator, denominator) + assert denominator != 0, "denominator can not be zero" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) return numerator // denominator @@ -80,7 +78,6 @@ def get_tensor_parallel_mode(): def _ntuple(n): - def parse(x): if isinstance(x, collections.abc.Iterable): return x diff --git a/colossalai/legacy/nn/layer/vanilla/__init__.py b/colossalai/legacy/nn/layer/vanilla/__init__.py index 3d767b8886f5..5785bbef33d7 100644 --- a/colossalai/legacy/nn/layer/vanilla/__init__.py +++ b/colossalai/legacy/nn/layer/vanilla/__init__.py @@ -9,6 +9,11 @@ ) __all__ = [ - "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath", - "VanillaLinear" + "VanillaLayerNorm", + "VanillaPatchEmbedding", + "VanillaClassifier", + "DropPath", + "WrappedDropout", + "WrappedDropPath", + "VanillaLinear", ] diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py index 71ca1d421de6..12965a4a6409 100644 --- a/colossalai/legacy/nn/layer/vanilla/layers.py +++ b/colossalai/legacy/nn/layer/vanilla/layers.py @@ -15,7 +15,7 @@ from ..utils import to_2tuple -def drop_path(x, drop_prob: float = 0., training: bool = False): +def drop_path(x, drop_prob: float = 0.0, training: bool = False): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, @@ -28,12 +28,12 @@ def drop_path(x, drop_prob: float = 0., training: bool = False): drop_prob (float, optional): probability of dropping path, defaults 0.0. training (bool, optional): whether in training progress, defaults False. """ - if drop_prob == 0. or not training: + if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize + random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return output @@ -74,8 +74,7 @@ class WrappedDropout(nn.Module): def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): super().__init__() if p < 0 or p > 1: - raise ValueError("dropout probability has to be between 0 and 1, " - "but got {}".format(p)) + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) self.p = p self.inplace = inplace if mode is None: @@ -108,7 +107,7 @@ class WrappedDropPath(nn.Module): in `parallel_mode `_ """ - def __init__(self, p: float = 0., mode=None): + def __init__(self, p: float = 0.0, mode=None): super().__init__() self.p = p self.mode = mode @@ -152,16 +151,18 @@ class VanillaPatchEmbedding(nn.Module): `init `_. """ - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - flatten: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + flatten: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_(), + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -172,11 +173,13 @@ def __init__(self, self.flatten = flatten self.weight = nn.Parameter( - torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) + torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype) + ) self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype)) self.pos_embed = nn.Parameter( - torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)) + torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype) + ) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) @@ -188,11 +191,12 @@ def reset_parameters(self, weight_initializer, bias_initializer, position_embed_ def forward(self, input_: Tensor) -> Tensor: B, C, H, W = input_.shape - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size) if self.flatten: - output = output.flatten(2).transpose(1, 2) # BCHW -> BNC + output = output.flatten(2).transpose(1, 2) # BCHW -> BNC cls_token = self.cls_token.expand(output.shape[0], -1, -1) output = torch.cat((cls_token, output), dim=1) @@ -219,14 +223,16 @@ class VanillaClassifier(nn.Module): `init `_. """ - def __init__(self, - in_features: int, - num_classes: int, - weight: nn.Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + num_classes: int, + weight: nn.Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() self.in_features = in_features self.num_classes = num_classes @@ -236,7 +242,8 @@ def __init__(self, self.has_weight = False else: self.weight = nn.Parameter( - torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)) + torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype) + ) self.has_weight = True if bias: self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype)) @@ -280,7 +287,7 @@ def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): self.normalized_shape = (normalized_shape,) self.variance_epsilon = eps - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs)) if bias: @@ -311,20 +318,22 @@ class VanillaLinear(nn.Module): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - **kwargs) -> None: + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + **kwargs, + ) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.skip_bias_add = skip_bias_add - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) if bias: self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) diff --git a/colossalai/legacy/nn/layer/wrapper/__init__.py b/colossalai/legacy/nn/layer/wrapper/__init__.py index c7d90d887ec6..4f3a33645344 100644 --- a/colossalai/legacy/nn/layer/wrapper/__init__.py +++ b/colossalai/legacy/nn/layer/wrapper/__init__.py @@ -1,3 +1,3 @@ from .pipeline_wrapper import PipelineSharedModuleWrapper -__all__ = ['PipelineSharedModuleWrapper'] +__all__ = ["PipelineSharedModuleWrapper"] diff --git a/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py index ec19d1b707d8..55445eb4d35a 100644 --- a/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py +++ b/colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py @@ -8,9 +8,8 @@ class PipelineSharedModuleWrapper: - def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None: - assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}' + assert len(pipeline_ranks) > 1, f"Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}" self.pipeline_ranks = pipeline_ranks self.group = None self.ranks_in_group = None @@ -33,16 +32,18 @@ def _init_group(self): self.ranks_in_group = sub_ranks def register_module(self, module: nn.Module): - assert self.ranks_in_group is not None,\ - f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' + assert ( + self.ranks_in_group is not None + ), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}" src = self.ranks_in_group[self.pipeline_ranks[0]] for p in module.parameters(): - setattr(p, 'pipeline_shared_module_pg', self.group) + setattr(p, "pipeline_shared_module_pg", self.group) dist.broadcast(p, src, group=self.group) def register_parameter(self, param: nn.Parameter): - assert self.ranks_in_group is not None,\ - f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}' + assert ( + self.ranks_in_group is not None + ), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}" src = self.ranks_in_group[self.pipeline_ranks[0]] - setattr(param, 'pipeline_shared_module_pg', self.group) + setattr(param, "pipeline_shared_module_pg", self.group) dist.broadcast(param, src, group=self.group) diff --git a/colossalai/legacy/nn/loss/__init__.py b/colossalai/legacy/nn/loss/__init__.py index abb7ec3ef824..43e5a5a2e2aa 100644 --- a/colossalai/legacy/nn/loss/__init__.py +++ b/colossalai/legacy/nn/loss/__init__.py @@ -11,28 +11,27 @@ from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D _parallel_cross_entropy = { - '2d': CrossEntropyLoss2D, - '2.5d': CrossEntropyLoss2p5D, - '3d': CrossEntropyLoss3D, + "2d": CrossEntropyLoss2D, + "2.5d": CrossEntropyLoss2p5D, + "3d": CrossEntropyLoss3D, } _vocab_parallel_cross_entropy = { - '1d': VocabParallelCrossEntropyLoss1D, - '2d': VocabParallelCrossEntropyLoss2D, - '2.5d': VocabParallelCrossEntropyLoss2p5D, - '3d': VocabParallelCrossEntropyLoss3D, + "1d": VocabParallelCrossEntropyLoss1D, + "2d": VocabParallelCrossEntropyLoss2D, + "2.5d": VocabParallelCrossEntropyLoss2p5D, + "3d": VocabParallelCrossEntropyLoss3D, } class CrossEntropyLoss(_Loss): - def __init__(self, reduction: bool = True, *args, **kwargs): super().__init__() tensor_parallel = get_tensor_parallel_mode() if tensor_parallel is not None and env.vocab_parallel: self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) - elif tensor_parallel is None or tensor_parallel == '1d': - reduction = 'mean' if reduction else 'none' + elif tensor_parallel is None or tensor_parallel == "1d": + reduction = "mean" if reduction else "none" self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs) else: self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs) diff --git a/colossalai/legacy/nn/loss/loss_1d.py b/colossalai/legacy/nn/loss/loss_1d.py index 2582e8b359d5..fae9c929b788 100644 --- a/colossalai/legacy/nn/loss/loss_1d.py +++ b/colossalai/legacy/nn/loss/loss_1d.py @@ -9,7 +9,6 @@ class _VocabParallelCrossEntropy1D(torch.autograd.Function): - @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, vocab_parallel_logits, targets, process_group): @@ -61,7 +60,6 @@ def forward(ctx, vocab_parallel_logits, targets, process_group): @staticmethod @custom_bwd def backward(ctx, grad_output): - # Retrieve tensors from the forward path. softmax, target_mask, masked_target_1d = ctx.saved_tensors @@ -73,7 +71,7 @@ def backward(ctx, grad_output): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(grad_output.unsqueeze(dim=-1)) diff --git a/colossalai/legacy/nn/loss/loss_2d.py b/colossalai/legacy/nn/loss/loss_2d.py index 7ab58415608a..44f39a6db262 100644 --- a/colossalai/legacy/nn/loss/loss_2d.py +++ b/colossalai/legacy/nn/loss/loss_2d.py @@ -50,7 +50,7 @@ def forward(self, logits, targets): float: the loss between logits and targets. """ targets = split_batch_2d(targets) - loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() loss = reduce_by_batch_2d(loss, True) @@ -69,9 +69,9 @@ def forward(ctx, logits, targets): # vocab_parallel_logits: [b/q, s, v/q] # target: [b/q, s] logits_max = torch.max(logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW) + ) # Subtract the maximum value. # vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) logits = logits - logits_max.unsqueeze(dim=-1) @@ -90,7 +90,7 @@ def forward(ctx, logits, targets): end=logits.size()[0], ) predicted_logits = logits[arange_1d, masked_target] - predicted_logits[target_mask] = 0. + predicted_logits[target_mask] = 0.0 dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW)) exp_logits = torch.exp(logits) @@ -119,7 +119,7 @@ def backward(ctx, output_grad): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) - grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(output_grad.unsqueeze(dim=-1)) diff --git a/colossalai/legacy/nn/loss/loss_2p5d.py b/colossalai/legacy/nn/loss/loss_2p5d.py index 8a5d04a8c788..c57bf26e9139 100644 --- a/colossalai/legacy/nn/loss/loss_2p5d.py +++ b/colossalai/legacy/nn/loss/loss_2p5d.py @@ -47,7 +47,7 @@ def forward(self, logits, targets): targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. """ targets = split_batch_2p5d(targets) - loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() loss = reduce_by_batch_2p5d(loss, True) @@ -64,9 +64,9 @@ def forward(ctx, logits, targets): # loss: [b/dq] # targets: [b/dq, h/q] logits_max = torch.max(logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW) + ) # Subtract the maximum value. logits = logits - logits_max.unsqueeze(dim=-1) @@ -84,7 +84,7 @@ def forward(ctx, logits, targets): end=logits.size()[0], ) predicted_logits = logits[arange_1d, masked_target] - predicted_logits[target_mask] = 0. + predicted_logits[target_mask] = 0.0 dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW)) exp_logits = torch.exp(logits) @@ -113,7 +113,7 @@ def backward(ctx, output_grad): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) - grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(output_grad.unsqueeze(dim=-1)) diff --git a/colossalai/legacy/nn/loss/loss_3d.py b/colossalai/legacy/nn/loss/loss_3d.py index a576d84f71cd..988317cae3eb 100644 --- a/colossalai/legacy/nn/loss/loss_3d.py +++ b/colossalai/legacy/nn/loss/loss_3d.py @@ -49,7 +49,7 @@ def forward(self, logits, targets): """ targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) targets = split_tensor_3d(targets, 0, self.input_parallel_mode) - loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) + loss = cross_entropy(logits, targets, reduction="none", *self.loss_args, **self.loss_kwargs) if self.reduction_mean: loss = loss.mean() loss = reduce_by_batch_3d(loss, self.input_parallel_mode, self.weight_parallel_mode, True) @@ -83,7 +83,7 @@ def forward(ctx, logits, targets, output_parallel_mode): arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device()) predicted_logits = logits[arange_1d, masked_target] predicted_logits = predicted_logits.clone().contiguous().view_as(targets) - predicted_logits[target_mask] = 0. + predicted_logits[target_mask] = 0.0 dist.all_reduce(predicted_logits, group=gpc.get_group(output_parallel_mode)) # Loss = log(sum(exp(logits))) - predicted-logit. @@ -111,7 +111,7 @@ def backward(ctx, output_grad): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device()) - grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float() input_grad.mul_(output_grad.unsqueeze(dim=-1)) return input_grad, None, None, None diff --git a/colossalai/legacy/nn/metric/__init__.py b/colossalai/legacy/nn/metric/__init__.py index 76c6dac89c5b..cc2b2c5d0254 100644 --- a/colossalai/legacy/nn/metric/__init__.py +++ b/colossalai/legacy/nn/metric/__init__.py @@ -8,14 +8,13 @@ from .accuracy_3d import Accuracy3D _parallel_accuracy = { - '2d': Accuracy2D, - '2.5d': Accuracy2p5D, - '3d': Accuracy3D, + "2d": Accuracy2D, + "2.5d": Accuracy2p5D, + "3d": Accuracy3D, } class Accuracy(nn.Module): - def __init__(self): super().__init__() tensor_parallel = get_tensor_parallel_mode() diff --git a/colossalai/legacy/nn/metric/accuracy_2d.py b/colossalai/legacy/nn/metric/accuracy_2d.py index 838c48834a96..59ddd5d66e20 100644 --- a/colossalai/legacy/nn/metric/accuracy_2d.py +++ b/colossalai/legacy/nn/metric/accuracy_2d.py @@ -7,8 +7,7 @@ class Accuracy2D(nn.Module): - """Accuracy for 2D parallelism - """ + """Accuracy for 2D parallelism""" def __init__(self): super().__init__() diff --git a/colossalai/legacy/nn/metric/accuracy_2p5d.py b/colossalai/legacy/nn/metric/accuracy_2p5d.py index 183380cd9846..948eae989d48 100644 --- a/colossalai/legacy/nn/metric/accuracy_2p5d.py +++ b/colossalai/legacy/nn/metric/accuracy_2p5d.py @@ -7,8 +7,7 @@ class Accuracy2p5D(nn.Module): - """Accuracy for 2p5D parallelism - """ + """Accuracy for 2p5D parallelism""" def __init__(self): super().__init__() diff --git a/colossalai/legacy/nn/metric/accuracy_3d.py b/colossalai/legacy/nn/metric/accuracy_3d.py index 675f5c2b5120..aee6118413ef 100644 --- a/colossalai/legacy/nn/metric/accuracy_3d.py +++ b/colossalai/legacy/nn/metric/accuracy_3d.py @@ -9,8 +9,7 @@ class Accuracy3D(nn.Module): - """Accuracy for 3D parallelism - """ + """Accuracy for 3D parallelism""" def __init__(self): super().__init__() @@ -26,7 +25,7 @@ def forward(self, logits, targets): Returns: float: the accuracy of prediction. - """ + """ with torch.no_grad(): targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) targets = split_tensor_3d(targets, 0, self.input_parallel_mode) diff --git a/colossalai/legacy/nn/parallel/__init__.py b/colossalai/legacy/nn/parallel/__init__.py index 17e010f478c9..19ad8404de18 100644 --- a/colossalai/legacy/nn/parallel/__init__.py +++ b/colossalai/legacy/nn/parallel/__init__.py @@ -1,5 +1,5 @@ from .data_parallel import ColoDDP __all__ = [ - 'ColoDDP', + "ColoDDP", ] diff --git a/colossalai/legacy/nn/parallel/data_parallel.py b/colossalai/legacy/nn/parallel/data_parallel.py index 2b2ad36a74f4..9634cb46a12a 100644 --- a/colossalai/legacy/nn/parallel/data_parallel.py +++ b/colossalai/legacy/nn/parallel/data_parallel.py @@ -49,11 +49,13 @@ class ColoDDP(torch.nn.Module): If it's None, the default data parallel group will be used. Defaults to None. """ - def __init__(self, - module: torch.nn.Module, - process_group: ColoProcessGroup, - bucket_cap_mb: int = 25, - rebuild_bucket: bool = True) -> None: + def __init__( + self, + module: torch.nn.Module, + process_group: ColoProcessGroup, + bucket_cap_mb: int = 25, + rebuild_bucket: bool = True, + ) -> None: assert not isinstance(module, ColoDDP) super().__init__() self.module = module @@ -74,19 +76,18 @@ def __init__(self, def parameters(self, recurse: bool = True): return self.module.parameters(recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True): + def named_parameters(self, prefix: str = "", recurse: bool = True): return self.module.named_parameters(prefix, recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True): + def named_buffers(self, prefix: str = "", recurse: bool = True): return self.module.named_buffers(prefix, recurse) def named_children(self): return self.module.named_children() - def named_modules(self, - memo: Optional[Set[torch.nn.Module]] = None, - prefix: str = '', - remove_duplicate: bool = True): + def named_modules( + self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): return self.module.named_modules(memo, prefix, remove_duplicate) def forward(self, *args, **kwargs): @@ -114,9 +115,9 @@ def grad_handle(self, p, grad): grad = grad / self.dp_world_size self.comm_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.comm_stream): - self.reducer.all_reduce_async(grad, - group=self.process_group.dp_process_group(), - callback_fn=partial(self._save_grad, p)) + self.reducer.all_reduce_async( + grad, group=self.process_group.dp_process_group(), callback_fn=partial(self._save_grad, p) + ) grad.record_stream(self.comm_stream) else: ColoDDP._save_grad(p, grad) @@ -130,7 +131,7 @@ def grad_handle(self, p, grad): @staticmethod def _save_grad(p, grad): - if hasattr(p, '_saved_grad'): + if hasattr(p, "_saved_grad"): p._saved_grad.add_(grad) else: p._saved_grad = grad @@ -138,7 +139,7 @@ def _save_grad(p, grad): def zero_grad(self, set_to_none: bool = False) -> None: self.module.zero_grad(set_to_none=True) for p in self.module.parameters(): - if getattr(p, '_saved_grad', None) is not None: + if getattr(p, "_saved_grad", None) is not None: if set_to_none: p._saved_grad = None else: @@ -167,8 +168,8 @@ def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None: for p in params_to_ignore: p._ddp_to_ignore = True - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): return self.module.load_state_dict(state_dict, strict) diff --git a/colossalai/legacy/nn/parallel/layers/__init__.py b/colossalai/legacy/nn/parallel/layers/__init__.py index f38124efedf7..2663076c6992 100644 --- a/colossalai/legacy/nn/parallel/layers/__init__.py +++ b/colossalai/legacy/nn/parallel/layers/__init__.py @@ -14,8 +14,20 @@ from .module_utils import check_colo_module, get_colo_module, init_colo_module, is_colo_module, register_colo_module __all__ = [ - 'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', - 'ColoLinear', 'ColoEmbedding', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'CachedParamMgr', - 'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', - 'ParallelCachedEmbeddingBagTablewiseSpiltCache' + "ColoModule", + "register_colo_module", + "is_colo_module", + "get_colo_module", + "init_colo_module", + "check_colo_module", + "ColoLinear", + "ColoEmbedding", + "CachedEmbeddingBag", + "ParallelCachedEmbeddingBag", + "CachedParamMgr", + "LimitBuffIndexCopyer", + "EvictionStrategy", + "ParallelCachedEmbeddingBagTablewise", + "TablewiseEmbeddingBagConfig", + "ParallelCachedEmbeddingBagTablewiseSpiltCache", ] diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py index d87930c1c6b3..aad6dcc5d7d8 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/__init__.py @@ -7,7 +7,12 @@ from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache __all__ = [ - 'CachedParamMgr', 'LimitBuffIndexCopyer', 'CachedEmbeddingBag', 'ParallelCachedEmbeddingBag', 'EvictionStrategy', - 'ParallelCachedEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig', - 'ParallelCachedEmbeddingBagTablewiseSpiltCache' + "CachedParamMgr", + "LimitBuffIndexCopyer", + "CachedEmbeddingBag", + "ParallelCachedEmbeddingBag", + "EvictionStrategy", + "ParallelCachedEmbeddingBagTablewise", + "TablewiseEmbeddingBagConfig", + "ParallelCachedEmbeddingBagTablewiseSpiltCache", ] diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py index 9558c541e703..3f825f11fe51 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py @@ -4,17 +4,16 @@ class BaseEmbeddingBag(abc.ABC, nn.Module): - def __init__( self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, - norm_type=2., + norm_type=2.0, scale_grad_by_freq=False, sparse=False, - mode='mean', + mode="mean", include_last_offset=False, ): super(BaseEmbeddingBag, self).__init__() @@ -22,9 +21,9 @@ def __init__( self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert padding_idx < self.num_embeddings, "Padding_idx must be within num_embeddings" elif padding_idx < 0: - assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + assert padding_idx >= -self.num_embeddings, "Padding_idx must be within num_embeddings" padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.max_norm = max_norm diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py index 16530c4ce7b8..e23864071e66 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -83,15 +83,16 @@ def __init__( if self._async_copy: self._memcpy_stream = torch.cuda.Stream() - print('use async copy') + print("use async copy") if self._evict_strategy == EvictionStrategy.LFU: # cache_row_idx -> frequency, freq of the cache rows. # classic lfu cache. evict the minimal freq value row in cuda cache. - self.register_buffer("freq_cnter", - torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), - dtype=torch.long).fill_(sys.maxsize), - persistent=False) + self.register_buffer( + "freq_cnter", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), dtype=torch.long).fill_(sys.maxsize), + persistent=False, + ) self._elapsed_dict = {} self._show_cache_miss = True self._reset_comm_stats() @@ -142,10 +143,10 @@ def _init_weight(self, weight): if self.cuda_row_num > 0: # Enable cache with introducing auxiliary data structures self.cuda_cached_weight = torch.nn.Parameter( - torch.zeros(self.cuda_row_num, - self.embedding_dim, - device=torch.cuda.current_device(), - dtype=weight.dtype)) + torch.zeros( + self.cuda_row_num, self.embedding_dim, device=torch.cuda.current_device(), dtype=weight.dtype + ) + ) # pin memory cpu for higher CPU-GPU copy bandwidth self.weight = weight.pin_memory() if self.pin_weight else weight @@ -158,17 +159,19 @@ def _init_weight(self, weight): ) # cached_idx_map: gpu_row_idx -> cpu_row_idx - self.register_buffer("cached_idx_map", - torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), - dtype=torch.long).fill_(-1), - persistent=False) + self.register_buffer( + "cached_idx_map", + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), dtype=torch.long).fill_(-1), + persistent=False, + ) # cpu_row_id -> gpu_row_idx. # gpu_row_idx as -1 means cpu_row_id not in CUDA. - self.register_buffer("inverted_cached_idx", - torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), - dtype=torch.long).fill_(-1), - persistent=False) + self.register_buffer( + "inverted_cached_idx", + torch.zeros(self.num_embeddings, device=torch.cuda.current_device(), dtype=torch.long).fill_(-1), + persistent=False, + ) self.evict_backlist = torch.tensor([], device=torch.cuda.current_device()) @@ -191,9 +194,11 @@ def cpu_weight_data(self, row_idx: int) -> torch.Tensor: torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D. """ - return self.weight.data.view(-1).narrow(0, - int(row_idx) * self.embedding_dim, - self.embedding_dim).view(1, self.embedding_dim) + return ( + self.weight.data.view(-1) + .narrow(0, int(row_idx) * self.embedding_dim, self.embedding_dim) + .view(1, self.embedding_dim) + ) @property def cuda_available_row_num(self): @@ -238,15 +243,18 @@ def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7 preload_cpu_ids = torch.arange(preload_row_num) preload_cuda_row_idxs = preload_cpu_ids.cuda() if self.buffer_size > 0: - self.limit_buff_index_copyer.index_copy(0, - src_index=preload_cpu_ids, - tgt_index=preload_cuda_row_idxs, - src=self.weight.view(self.num_embeddings, -1), - tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + self.limit_buff_index_copyer.index_copy( + 0, + src_index=preload_cpu_ids, + tgt_index=preload_cuda_row_idxs, + src=self.weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1), + ) else: preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda() - self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs, - preload_rows) + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_( + 0, preload_cuda_row_idxs, preload_rows + ) # update auxiliary info self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda() @@ -260,7 +268,7 @@ def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7 else: self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda() - print(f'Cache warmup finished cost {timer.elapsed} sec.') + print(f"Cache warmup finished cost {timer.elapsed} sec.") def flush(self): """flush all CUDA rows to CPU. @@ -290,18 +298,18 @@ def print_comm_stats(self): print( f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem" ) - print(f'cuda_to_cpu_elapse {elapsed} sec') + print(f"cuda_to_cpu_elapse {elapsed} sec") if self._cpu_to_cuda_numel > 0 and "5_evict_in" in self._elapsed_dict: elapsed = self._elapsed_dict["5_evict_in"] print( f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem" ) - print(f'cpu_to_cuda_elapse {elapsed} sec') + print(f"cpu_to_cuda_elapse {elapsed} sec") for k, v in self._elapsed_dict.items(): - print(f'{k}: {v}') + print(f"{k}: {v}") - print(f'cache miss ratio {self._cache_miss / self._total_cache}') + print(f"cache miss ratio {self._cache_miss / self._total_cache}") @torch.no_grad() def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor: @@ -336,10 +344,11 @@ def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor: else: cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True) - assert len(cpu_row_idxs) <= self.cuda_row_num, \ - f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \ - f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \ + assert len(cpu_row_idxs) <= self.cuda_row_num, ( + f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " + f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " f"Please increase cuda_row_num or decrease the training batch size." + ) self.evict_backlist = cpu_row_idxs tmp = torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True) comm_cpu_row_idxs = cpu_row_idxs[tmp] @@ -386,8 +395,9 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: # move evict in rows to gpu if self._async_copy: if self.buffer_size == 0: - evict_in_rows_gpu = self.weight.view(self.num_embeddings, - -1).index_select(0, cpu_row_idxs_copy).pin_memory() + evict_in_rows_gpu = ( + self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() + ) with torch.cuda.stream(self._memcpy_stream): evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True) else: @@ -409,9 +419,10 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: # move evict out rows to cpu if self._async_copy: - evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs) - evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) + evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select( + 0, evict_gpu_row_idxs + ) + evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device="cpu", pin_memory=True) with torch.cuda.stream(None): evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) @@ -425,9 +436,10 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) if self._async_copy: - evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs) - evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) + evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select( + 0, evict_gpu_row_idxs + ) + evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device="cpu", pin_memory=True) with torch.cuda.stream(None): evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) @@ -438,11 +450,13 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: with self.timer("3_evict_out") as timer: if self.buffer_size > 0: - self.limit_buff_index_copyer.index_copy(0, - src_index=evict_gpu_row_idxs, - tgt_index=evict_info.cpu(), - src=self.cuda_cached_weight.view(self.cuda_row_num, -1), - tgt=self.weight.view(self.num_embeddings, -1)) + self.limit_buff_index_copyer.index_copy( + 0, + src_index=evict_gpu_row_idxs, + tgt_index=evict_info.cpu(), + src=self.cuda_cached_weight.view(self.cuda_row_num, -1), + tgt=self.weight.view(self.num_embeddings, -1), + ) else: # allocate tmp memory on CPU and copy rows on CUDA to CPU. # TODO async gpu -> cpu @@ -450,8 +464,9 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: _wait_for_data(evict_out_rows_cpu, None) else: with self.timer("3_1_evict_out_index_select") as timer: - evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, - -1).index_select(0, evict_gpu_row_idxs) + evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select( + 0, evict_gpu_row_idxs + ) with self.timer("3_2_evict_out_gpu_to_cpu_copy") as timer: evict_out_rows_cpu = evict_out_rows_cpu.cpu() @@ -469,17 +484,19 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: # slots of cuda weight to evict in with self.timer("4_identify_cuda_slot") as timer: - slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()] + slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[: cpu_row_idxs.numel()] # TODO wait for optimize with self.timer("5_evict_in") as timer: # Here also allocate extra memory on CUDA. #cpu_row_idxs if self.buffer_size > 0: - self.limit_buff_index_copyer.index_copy(0, - src_index=cpu_row_idxs_copy, - tgt_index=slots, - src=self.weight.view(self.num_embeddings, -1), - tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) + self.limit_buff_index_copyer.index_copy( + 0, + src_index=cpu_row_idxs_copy, + tgt_index=slots, + src=self.weight.view(self.num_embeddings, -1), + tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1), + ) else: if self._async_copy: _wait_for_data(evict_in_rows_gpu, self._memcpy_stream) @@ -488,8 +505,9 @@ def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None: # narrow index select to a subset of self.weight # tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1) # evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu()) - evict_in_rows_gpu = self.weight.view(self.num_embeddings, - -1).index_select(0, cpu_row_idxs_copy).pin_memory() + evict_in_rows_gpu = ( + self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() + ) with self.timer("5_2_evict_in_gpu_to_cpu_copy") as timer: evict_in_rows_gpu = evict_in_rows_gpu.cuda() @@ -537,8 +555,9 @@ def _evict(self) -> int: self.cached_idx_map.index_copy_(0, idx, buf) with Timer() as timer: - cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim, - self.embedding_dim).view(1, self.embedding_dim) + cuda_tensor = torch.narrow( + self.cuda_cached_weight.view(-1), 0, max_offset * self.embedding_dim, self.embedding_dim + ).view(1, self.embedding_dim) self.cpu_weight_data(max_gpu_row_idx).data.copy_(cuda_tensor) # update inverted_cached_idx, min_slot_id is evicted from cuda @@ -570,8 +589,9 @@ def _admit(self, row_id: int): slot_offset = slot_id # copy payload from cpu to cuda with Timer() as timer: - cuda_tensor = torch.narrow(self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim, - self.embedding_dim).view(1, self.embedding_dim) + cuda_tensor = torch.narrow( + self.cuda_cached_weight.view(-1), 0, slot_offset * self.embedding_dim, self.embedding_dim + ).view(1, self.embedding_dim) cuda_tensor.data.copy_(self.cpu_weight_data(row_id)) # update the inverted_cached_idx diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py index bc7d178906da..03667857b1ac 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/cached_embedding.py @@ -36,27 +36,38 @@ class CachedEmbeddingBag(BaseEmbeddingBag): evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET. """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - max_norm: float = None, - norm_type: float = 2., - scale_grad_by_freq: bool = False, - sparse: bool = False, - _weight: Optional[torch.Tensor] = None, - mode: str = 'mean', - include_last_offset: bool = False, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - cache_ratio: float = 0.01, - ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, - warmup_ratio: float = 0.7, - buffer_size: int = 0, - pin_weight: bool = False, - evict_strategy: EvictionStrategy = EvictionStrategy.LFU): - super(CachedEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, - scale_grad_by_freq, sparse, mode, include_last_offset) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + max_norm: float = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[torch.Tensor] = None, + mode: str = "mean", + include_last_offset: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + cache_ratio: float = 0.01, + ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, + warmup_ratio: float = 0.7, + buffer_size: int = 0, + pin_weight: bool = False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU, + ): + super(CachedEmbeddingBag, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + mode, + include_last_offset, + ) assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0" self.evict_strategy = evict_strategy @@ -78,13 +89,15 @@ def _weight_alloc(self, dtype, device): weight[self.padding_idx].fill_(0) return weight - def _preprocess(self, - weight, - cuda_row_num: int, - ids_freq_mapping: Optional[List[int]] = None, - warmup_ratio=0.7, - buffer_size=50_000, - pin_weight=False): + def _preprocess( + self, + weight, + cuda_row_num: int, + ids_freq_mapping: Optional[List[int]] = None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + ): """ Called after initialized. Reorder the weight rows according to the ids_freq_mapping. @@ -95,11 +108,9 @@ def _preprocess(self, ids_freq_mapping (List[int]): a list, idx is id number, value is freq warmup_ratio (float): the amount of rows preloaded in cuda cache """ - self.cache_weight_mgr = CachedParamMgr(weight, - cuda_row_num, - buffer_size, - pin_weight, - evict_strategy=self.evict_strategy) + self.cache_weight_mgr = CachedParamMgr( + weight, cuda_row_num, buffer_size, pin_weight, evict_strategy=self.evict_strategy + ) self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio) def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None): @@ -107,9 +118,19 @@ def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None) with torch.no_grad(): input = self.cache_weight_mgr.prepare_ids(input) - embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, - per_sample_weights, self.include_last_offset, self.padding_idx) + embeddings = F.embedding_bag( + input.cuda(), + self.cache_weight_mgr.cuda_cached_weight, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) if shape_hook is not None: embeddings = shape_hook(embeddings) return embeddings @@ -118,8 +139,8 @@ def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None) def weight(self): return self.cache_weight_mgr.weight - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: - yield 'weight', self.cache_weight_mgr.cuda_cached_weight + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + yield "weight", self.cache_weight_mgr.cuda_cached_weight def parameters(self, recurse: bool = True) -> Iterator[Parameter]: yield self.cache_weight_mgr.cuda_cached_weight @@ -127,8 +148,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: def set_cache_op(self, cache_op: bool = True): self.cache_op = cache_op - -############################# Perf Log ################################### + ############################# Perf Log ################################### @property def num_hits_history(self): @@ -145,14 +165,22 @@ def num_write_back_history(self): @property def swap_in_bandwidth(self): if self.cache_weight_mgr._cpu_to_cuda_numel > 0: - return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ - self.cache_weight_mgr._cpu_to_cuda_elapse + return ( + self.cache_weight_mgr._cpu_to_cuda_numel + * self.cache_weight_mgr.elem_size_in_byte + / 1e6 + / self.cache_weight_mgr._cpu_to_cuda_elapse + ) else: return 0 @property def swap_out_bandwidth(self): if self.cache_weight_mgr._cuda_to_cpu_numel > 0: - return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \ - self.cache_weight_mgr._cuda_to_cpu_elapse + return ( + self.cache_weight_mgr._cuda_to_cpu_numel + * self.cache_weight_mgr.elem_size_in_byte + / 1e6 + / self.cache_weight_mgr._cuda_to_cpu_elapse + ) return 0 diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py index 804a07f88207..5e3a8df05cfe 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/copyer.py @@ -39,7 +39,7 @@ def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src for begin_pos in range(0, dim_size, self._buff_size): cur_len = min(self._buff_size, dim_size - begin_pos) src_idx_piece = src_index.narrow(0, begin_pos, cur_len) - if src_device.type == 'cpu' and tgt_device.type == 'cuda': + if src_device.type == "cpu" and tgt_device.type == "cuda": cpu_tmp_buffer = src.index_select(dim, src_idx_piece).pin_memory() tmp_buffer = torch.empty_like(cpu_tmp_buffer, device=tgt_device) tmp_buffer.copy_(cpu_tmp_buffer) diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py index 36e04c833feb..ceaa9081c724 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/embedding_config.py @@ -2,22 +2,24 @@ class TablewiseEmbeddingBagConfig: - ''' + """ example: def prepare_tablewise_config(args, cache_ratio, ...): embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] ... return embedding_bag_config_list - ''' + """ - def __init__(self, - num_embeddings: int, - cuda_row_num: int, - assigned_rank: int = 0, - buffer_size=50_000, - ids_freq_mapping=None, - initial_weight: torch.tensor = None, - name: str = ""): + def __init__( + self, + num_embeddings: int, + cuda_row_num: int, + assigned_rank: int = 0, + buffer_size=50_000, + ids_freq_mapping=None, + initial_weight: torch.tensor = None, + name: str = "", + ): self.num_embeddings = num_embeddings self.cuda_row_num = cuda_row_num self.assigned_rank = assigned_rank diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py index 522fb4f4497f..ee739935fef2 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py @@ -1,13 +1,13 @@ -from typing import Iterator, List, Optional, Tuple +from typing import List, Optional, Tuple import torch import torch.nn.functional as F from colossalai.legacy.nn._ops._utils import dual_all_to_all from colossalai.legacy.tensor import ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec -from colossalai.tensor import ColoParameter, ColoTensor +from colossalai.tensor import ColoTensor -from .cache_mgr import CachedParamMgr, EvictionStrategy +from .cache_mgr import EvictionStrategy from .cached_embedding import CachedEmbeddingBag @@ -15,9 +15,9 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: if world_size == 1: return 0, embedding_dim, True - assert embedding_dim >= world_size, \ - f"Embedding dimension {embedding_dim} must be larger than the world size " \ - f"{world_size} of the process group" + assert embedding_dim >= world_size, ( + f"Embedding dimension {embedding_dim} must be larger than the world size " f"{world_size} of the process group" + ) chunk_size = embedding_dim // world_size threshold = embedding_dim % world_size # if embedding dim is divisible by world size @@ -31,37 +31,55 @@ def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: class ParallelCachedEmbeddingBag(CachedEmbeddingBag): - - def __init__(self, - num_embeddings, - embedding_dim, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - _weight=None, - mode='mean', - include_last_offset=False, - dtype=None, - device=None, - cache_ratio=0.01, - ids_freq_mapping=None, - warmup_ratio=0.7, - buffer_size=50_000, - pin_weight=False, - evict_strategy: EvictionStrategy = EvictionStrategy.DATASET): + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode="mean", + include_last_offset=False, + dtype=None, + device=None, + cache_ratio=0.01, + ids_freq_mapping=None, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, + ): self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() self.partition_start_index, self.partition_end_index, divisible = get_partition( - embedding_dim, self.rank, self.world_size) + embedding_dim, self.rank, self.world_size + ) self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index - super(ParallelCachedEmbeddingBag, - self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, - sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, - warmup_ratio, buffer_size, pin_weight, evict_strategy) + super(ParallelCachedEmbeddingBag, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + mode, + include_last_offset, + dtype, + device, + cache_ratio, + ids_freq_mapping, + warmup_ratio, + buffer_size, + pin_weight, + evict_strategy, + ) self.cache_op = True def _weight_alloc(self, dtype, device): @@ -70,9 +88,11 @@ def _weight_alloc(self, dtype, device): weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings) if self.padding_idx is not None: weight[self.padding_idx].fill_(0) - colo_tensor_spec = ColoTensorSpec(pg=ProcessGroup(tp_degree=self.world_size), - dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]), - compute_attr=ComputePattern.TP1D) + colo_tensor_spec = ColoTensorSpec( + pg=ProcessGroup(tp_degree=self.world_size), + dist_attr=ShardSpec(dims=[-1], num_partitions=[self.world_size]), + compute_attr=ComputePattern.TP1D, + ) return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec) def forward( @@ -87,15 +107,24 @@ def forward( if self.cache_op: with torch.no_grad(): indices = self.cache_weight_mgr.prepare_ids(indices) - output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, - per_sample_weights, self.include_last_offset, self.padding_idx) + output_shard = F.embedding_bag( + indices.cuda(), + self.cache_weight_mgr.cuda_cached_weight, + offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) if shape_hook is not None: output_shard = shape_hook(output_shard) - output_full = dual_all_to_all(output_shard, - self.weight.get_process_group(), - scatter_dim=scatter_dim, - gather_dim=gather_dim) + output_full = dual_all_to_all( + output_shard, self.weight.get_process_group(), scatter_dim=scatter_dim, gather_dim=gather_dim + ) return output_full def set_cache_op(self, cache_op: bool = True): @@ -108,31 +137,33 @@ def from_pretrained( freeze: bool = True, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, - norm_type: float = 2., + norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False, - mode: str = 'mean', + mode: str = "mean", include_last_offset: bool = False, cuda_row_num: int = 100_000, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio: float = 0.7, buffer_size: int = 0, - ) -> 'ParallelCachedEmbeddingBag': + ) -> "ParallelCachedEmbeddingBag": rows, cols = embedding.shape - embedding_bag = cls(rows, - cols, - padding_idx, - max_norm, - norm_type, - scale_grad_by_freq, - sparse, - embedding, - mode, - include_last_offset, - cuda_row_num=cuda_row_num, - ids_freq_mapping=ids_freq_mapping, - warmup_ratio=warmup_ratio, - buffer_size=buffer_size) + embedding_bag = cls( + rows, + cols, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + embedding, + mode, + include_last_offset, + cuda_row_num=cuda_row_num, + ids_freq_mapping=ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=buffer_size, + ) embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_ = not freeze return embedding_bag diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py index a1feda2bdb0e..7d21f5b68ce6 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py @@ -1,4 +1,3 @@ -import time from typing import List import torch @@ -19,24 +18,26 @@ class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag): Those parameters in TablewiseEmbeddingBagConfig are ignored: cuda_row_num, buffer_size, initial_weight. """ - def __init__(self, - embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], - embedding_dim: int, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - _weight=None, - mode='mean', - include_last_offset=False, - dtype=None, - device=None, - cache_ratio=0.01, - warmup_ratio=0.7, - buffer_size=50_000, - pin_weight=False, - evict_strategy: EvictionStrategy = EvictionStrategy.LFU): + def __init__( + self, + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], + embedding_dim: int, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + mode="mean", + include_last_offset=False, + dtype=None, + device=None, + cache_ratio=0.01, + warmup_ratio=0.7, + buffer_size=50_000, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU, + ): self.rank = dist.get_rank() self.world_size = dist.get_world_size() self.rank_of_tables = [config.assigned_rank for config in embedding_bag_config_list] @@ -62,11 +63,27 @@ def __init__(self, break self.cache_ratio = cache_ratio # table-associate cache - cuda_row_num = int(cache_ratio * self.num_embeddings) - super(ParallelCachedEmbeddingBagTablewise, - self).__init__(self.num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, - sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping, - warmup_ratio, buffer_size, pin_weight, evict_strategy) + int(cache_ratio * self.num_embeddings) + super(ParallelCachedEmbeddingBagTablewise, self).__init__( + self.num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + mode, + include_last_offset, + dtype, + device, + cache_ratio, + ids_freq_mapping, + warmup_ratio, + buffer_size, + pin_weight, + evict_strategy, + ) # for assigned tables reconnection: self.idx_offset_list = [] @@ -96,7 +113,8 @@ def forward( # not recommanded. it takes time. batch_size = (offsets.shape[0]) // self.global_tables_num local_indices, local_offsets, local_per_sample_weights = self.split_along_rank( - batch_size, indices, offsets, per_sample_weights) + batch_size, indices, offsets, per_sample_weights + ) else: # recommanded. batch_size = (offsets.shape[0]) // len(self.assigned_table_list) @@ -104,9 +122,19 @@ def forward( if self.cache_op: with torch.no_grad(): indices = self.cache_weight_mgr.prepare_ids(local_indices) - local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets, - self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, - local_per_sample_weights, self.include_last_offset, self.padding_idx) + local_output = F.embedding_bag( + indices.cuda(), + self.cache_weight_mgr.cuda_cached_weight, + local_offsets, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.mode, + self.sparse, + local_per_sample_weights, + self.include_last_offset, + self.padding_idx, + ) local_output = torch.cat(local_output.split(batch_size), 1) remains = batch_size % self.world_size scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)] @@ -115,21 +143,19 @@ def forward( output_full = shape_hook(output_full) return output_full - def split_along_rank(self, - batch_size, - indices: torch.Tensor, - offsets: torch.Tensor = None, - per_sample_weights=None): - ''' + def split_along_rank( + self, batch_size, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None + ): + """ if input indices and offsets haven't been splitted along assigned rank, this function will do it. it takes time. please consider splitting data during batch loading. - ''' + """ local_indices_list: List(torch.Tensor) = [] local_offsets_list: List(torch.Tensor) = [] if per_sample_weights != None: local_per_sample_weights_list: List(torch.Tensor) = [] - offset_pre_end = 0 # local_offsets trick + offset_pre_end = 0 # local_offsets trick for i, handle_table in enumerate(self.assigned_table_list): indices_start_position = offsets[batch_size * handle_table] if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]): @@ -138,7 +164,7 @@ def split_along_rank(self, else: indices_end_position = offsets[batch_size * (handle_table + 1)] # alternative approach: reduce malloc - ''' + """ # 1. local_indices_list: local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position) torch.sub(local_indices, self.idx_offset_list[i], out=local_indices) @@ -158,25 +184,29 @@ def split_along_rank(self, torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets) offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder local_offsets_list.append(local_offsets) - ''' + """ # 1. local_indices_list: local_indices_list.append( - indices.narrow(0, indices_start_position, - indices_end_position - indices_start_position).sub(self.idx_offset_list[i])) + indices.narrow(0, indices_start_position, indices_end_position - indices_start_position).sub( + self.idx_offset_list[i] + ) + ) # 2. local_offsets_list: if i + 1 == len(self.assigned_table_list): # till-the-end special case if not self.include_last_offset: - local_offsets = offsets.narrow(0, batch_size * handle_table, - batch_size).add(offset_pre_end - offsets[batch_size * - (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size).add( + offset_pre_end - offsets[batch_size * (handle_table)] + ) else: - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + - 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).add( + offset_pre_end - offsets[batch_size * (handle_table)] + ) local_offsets_list.append(local_offsets) else: - local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + - 1).add(offset_pre_end - offsets[batch_size * (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).add( + offset_pre_end - offsets[batch_size * (handle_table)] + ) offset_pre_end = local_offsets[-1] local_offsets_list.append(local_offsets[:-1]) # 3. local_per_sample_weights_list: diff --git a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py index 8017ee72b0b4..94a27a8673da 100644 --- a/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py +++ b/colossalai/legacy/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py @@ -19,21 +19,23 @@ class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): every table assigned to this class instance is managed by a CachedEmbeddingBag. """ - def __init__(self, - embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], - embedding_dim: int, - padding_idx=None, - max_norm=None, - norm_type=2., - scale_grad_by_freq=False, - sparse=False, - mode='mean', - include_last_offset=False, - dtype=None, - device=None, - warmup_ratio=0.7, - pin_weight=False, - evict_strategy: EvictionStrategy = EvictionStrategy.LFU): + def __init__( + self, + embedding_bag_config_list: List[TablewiseEmbeddingBagConfig], + embedding_dim: int, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + mode="mean", + include_last_offset=False, + dtype=None, + device=None, + warmup_ratio=0.7, + pin_weight=False, + evict_strategy: EvictionStrategy = EvictionStrategy.LFU, + ): super(ParallelCachedEmbeddingBagTablewiseSpiltCache, self).__init__() self.rank = dist.get_rank() self.world_size = dist.get_world_size() @@ -56,24 +58,27 @@ def __init__(self, if config.assigned_rank != self.rank: continue self.cached_embedding_bag_list.append( - CachedEmbeddingBag(num_embeddings=config.num_embeddings, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - _weight=config.initial_weight, - mode=mode, - include_last_offset=include_last_offset, - dtype=dtype, - device=device, - cuda_row_num=config.cuda_row_num, - ids_freq_mapping=config.ids_freq_mapping, - warmup_ratio=warmup_ratio, - buffer_size=config.buffer_size, - pin_weight=pin_weight, - evict_strategy=evict_strategy)) + CachedEmbeddingBag( + num_embeddings=config.num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=config.initial_weight, + mode=mode, + include_last_offset=include_last_offset, + dtype=dtype, + device=device, + cuda_row_num=config.cuda_row_num, + ids_freq_mapping=config.ids_freq_mapping, + warmup_ratio=warmup_ratio, + buffer_size=config.buffer_size, + pin_weight=pin_weight, + evict_strategy=evict_strategy, + ) + ) # prepare list shape for all_to_all output self.embedding_dim_per_rank = [0 for i in range(self.world_size)] @@ -95,22 +100,26 @@ def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sampl indices_end_position = offsets[batch_size * (handle_table + 1)] with record_function("part 2"): # local_indices = indices[indices_start_position:indices_end_position] - self.global_tables_offsets[handle_table] - local_indices = indices.narrow(0, indices_start_position, indices_end_position - - indices_start_position).sub(self.global_tables_offsets[handle_table]) + local_indices = indices.narrow( + 0, indices_start_position, indices_end_position - indices_start_position + ).sub(self.global_tables_offsets[handle_table]) if self.include_last_offset: # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)] - local_offsets = offsets.narrow(0, batch_size * handle_table, - batch_size + 1).sub(offsets[batch_size * (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1).sub( + offsets[batch_size * (handle_table)] + ) else: # local_offsets = offsets[batch_size * handle_table:batch_size * (handle_table + 1)] - offsets[batch_size * (handle_table)] - local_offsets = offsets.narrow(0, batch_size * handle_table, - batch_size).sub(offsets[batch_size * (handle_table)]) + local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size).sub( + offsets[batch_size * (handle_table)] + ) local_per_sample_weights = None if per_sample_weights != None: local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position] with record_function("(tablewise) tablewise forward"): - local_output_list.append(self.cached_embedding_bag_list[i](local_indices, local_offsets, - local_per_sample_weights)) + local_output_list.append( + self.cached_embedding_bag_list[i](local_indices, local_offsets, local_per_sample_weights) + ) # get result of shape = (batch_size, (len(assigned_table_list)*embedding_dim)) local_output = torch.cat(local_output_list, 1) diff --git a/colossalai/legacy/nn/parallel/layers/colo_module.py b/colossalai/legacy/nn/parallel/layers/colo_module.py index 69d92afaaa94..df0b324eeeb8 100644 --- a/colossalai/legacy/nn/parallel/layers/colo_module.py +++ b/colossalai/legacy/nn/parallel/layers/colo_module.py @@ -5,7 +5,6 @@ class ColoModule(object): - def __init__(self): self._shard_params: List[str] = [] self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {} @@ -13,18 +12,18 @@ def __init__(self): def _register_shard_params(self, params: List[str]): self._shard_params = params - def _register_allowed_patterns(self, - compute_pattern: ComputePattern, - dist_specs: Dict[str, _DistSpec], - mode='default'): - assert list( - dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.' + def _register_allowed_patterns( + self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode="default" + ): + assert ( + list(dist_specs.keys()).sort() == self._shard_params.sort() + ), "Every registered param should have dist_spec." if not compute_pattern in self._allowed_patterns: self._allowed_patterns[compute_pattern] = {} self._allowed_patterns[compute_pattern][mode] = dist_specs def _set_default(self, compute_pattern: ComputePattern, target_mode): - self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_mode] + self._allowed_patterns[compute_pattern]["default"] = self._allowed_patterns[compute_pattern][target_mode] def has_compute_pattern(self, compute_pattern: ComputePattern): return compute_pattern in self._allowed_patterns @@ -33,10 +32,10 @@ def get_dist_specs(self, compute_pattern: ComputePattern): assert self.has_compute_pattern(compute_pattern) return self._allowed_patterns[compute_pattern] - def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode='default'): + def has_compute_pattern_with_mode(self, compute_pattern: ComputePattern, mode="default"): return compute_pattern in self._allowed_patterns and mode in self._allowed_patterns[compute_pattern] - def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode='default'): + def get_dist_specs_with_mode(self, compute_pattern: ComputePattern, mode="default"): assert self.has_compute_pattern_with_mode(compute_pattern, mode) return self._allowed_patterns[compute_pattern][mode] diff --git a/colossalai/legacy/nn/parallel/layers/embedding.py b/colossalai/legacy/nn/parallel/layers/embedding.py index 4796699fc57f..f204f3fb71f0 100644 --- a/colossalai/legacy/nn/parallel/layers/embedding.py +++ b/colossalai/legacy/nn/parallel/layers/embedding.py @@ -1,13 +1,12 @@ -from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec +from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec from .colo_module import ColoModule class ColoEmbedding(ColoModule): - def __init__(self): super(ColoEmbedding, self).__init__() - self._register_shard_params(['weight']) + self._register_shard_params(["weight"]) def register(self, compute_pattern, pg: ProcessGroup): if not compute_pattern in self._allowed_patterns: @@ -20,18 +19,18 @@ def _set_TP1D(self, pg: ProcessGroup): self._register_allowed_patterns( compute_pattern=_compute_pattern, dist_specs={ - 'weight': ShardSpec([0], [pg.tp_world_size()]), + "weight": ShardSpec([0], [pg.tp_world_size()]), }, - mode='row', + mode="row", ) # TP1D Col Linear self._register_allowed_patterns( compute_pattern=_compute_pattern, dist_specs={ - 'weight': ShardSpec([-1], [pg.tp_world_size()]), + "weight": ShardSpec([-1], [pg.tp_world_size()]), }, - mode='col', + mode="col", ) - self._set_default(compute_pattern=_compute_pattern, target_mode='row') + self._set_default(compute_pattern=_compute_pattern, target_mode="row") diff --git a/colossalai/legacy/nn/parallel/layers/linear.py b/colossalai/legacy/nn/parallel/layers/linear.py index 51a8d4c976a6..c3b6df1ec9da 100644 --- a/colossalai/legacy/nn/parallel/layers/linear.py +++ b/colossalai/legacy/nn/parallel/layers/linear.py @@ -1,13 +1,12 @@ -from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec +from colossalai.legacy.tensor import ComputePattern, ProcessGroup, ShardSpec from .colo_module import ColoModule class ColoLinear(ColoModule): - def __init__(self): super(ColoLinear, self).__init__() - self._register_shard_params(['weight', 'bias']) + self._register_shard_params(["weight", "bias"]) def register(self, compute_pattern, pg: ProcessGroup): if not compute_pattern in self._allowed_patterns: @@ -19,21 +18,15 @@ def _set_TP1D(self, pg): _compute_pattern = ComputePattern.TP1D self._register_allowed_patterns( compute_pattern=_compute_pattern, - dist_specs={ - 'weight': ShardSpec([-1], [pg.tp_world_size()]), - 'bias': None - }, - mode='row', + dist_specs={"weight": ShardSpec([-1], [pg.tp_world_size()]), "bias": None}, + mode="row", ) # TP1D Col Linear self._register_allowed_patterns( compute_pattern=_compute_pattern, - dist_specs={ - 'weight': ShardSpec([0], [pg.tp_world_size()]), - 'bias': ShardSpec([0], [pg.tp_world_size()]) - }, - mode='col', + dist_specs={"weight": ShardSpec([0], [pg.tp_world_size()]), "bias": ShardSpec([0], [pg.tp_world_size()])}, + mode="col", ) - self._set_default(compute_pattern=_compute_pattern, target_mode='row') + self._set_default(compute_pattern=_compute_pattern, target_mode="row") diff --git a/colossalai/legacy/nn/parallel/layers/module_utils.py b/colossalai/legacy/nn/parallel/layers/module_utils.py index 09326d2d6f9a..4dbce7e09f37 100644 --- a/colossalai/legacy/nn/parallel/layers/module_utils.py +++ b/colossalai/legacy/nn/parallel/layers/module_utils.py @@ -2,7 +2,7 @@ import torch -from colossalai.legacy.tensor import ComputeSpec, ProcessGroup, distspec +from colossalai.legacy.tensor import ComputeSpec, ProcessGroup from colossalai.tensor import ColoParameter from . import ColoModule @@ -41,7 +41,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) for param_name in param_names: param = module.get_parameter(param_name) if not isinstance(param, ColoParameter): - raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.') + raise Exception(f"Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.") if param.has_compute_spec(): cur_compute_pattern = param.compute_spec.compute_pattern if compute_pattern is None: @@ -49,7 +49,8 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) else: if cur_compute_pattern != compute_pattern: raise Exception( - f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.') + f"Invalid ColoParameter spec: Params in {module} have different compute_pattern." + ) else: continue @@ -57,7 +58,8 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) colo_module.register(compute_pattern, pg) if not colo_module.has_compute_pattern(compute_pattern): raise Exception( - f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.') + f"Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed." + ) match_specs = False allowed_specs = colo_module.get_dist_specs(compute_pattern) @@ -77,17 +79,15 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True) match_specs = True break if match_specs == False: - raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.') + raise Exception(f"Invalid ColoParameter spec: Params in {module} are incorrectly sharded.") if recursive == True: for submodule in module.children(): check_colo_module(submodule, pg=pg, recursive=True) -def init_colo_module(module: torch.nn.Module, - compute_spec: ComputeSpec, - pg: ProcessGroup, - recursive=True, - mode='default'): +def init_colo_module( + module: torch.nn.Module, compute_spec: ComputeSpec, pg: ProcessGroup, recursive=True, mode="default" +): compute_pattern = compute_spec.compute_pattern if is_colo_module(module): # for each param diff --git a/colossalai/legacy/nn/parallel/reducer.py b/colossalai/legacy/nn/parallel/reducer.py index 5687055819fe..7b3d283e47dd 100644 --- a/colossalai/legacy/nn/parallel/reducer.py +++ b/colossalai/legacy/nn/parallel/reducer.py @@ -13,7 +13,6 @@ class Bucket: - def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): self.buffer = torch.zeros(size, dtype=dtype, device=device) self.group = group @@ -26,7 +25,7 @@ def flush(self) -> None: assert len(self.callbacks) == 0 return # reduce-scatter bucket - dist.all_reduce(self.buffer[:self.offset], group=self.group) + dist.all_reduce(self.buffer[: self.offset], group=self.group) # execute post-reduction callbacks for callback_fn in self.callbacks: @@ -37,24 +36,22 @@ def flush(self) -> None: self.buffer = torch.zeros_like(self.buffer) def alloc(self) -> None: - if self.buffer.storage().size() == 0: self.buffer.storage().resize_(self.buffer.numel()) def free(self) -> None: - assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown" self.buffer.storage().resize_(0) def append(self, tensor: Tensor, callback_fn: Callable): tensor_size = tensor.numel() offset = self.offset - self.buffer[offset:offset + tensor_size].copy_(tensor.flatten()) + self.buffer[offset : offset + tensor_size].copy_(tensor.flatten()) self.offset += tensor_size # callback will be given the reduced result if callback_fn is not None: - result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape) + result_view = self.buffer[offset : offset + tensor_size].view(tensor.shape) self.callbacks.append(functools.partial(callback_fn, result_view)) @property @@ -63,7 +60,6 @@ def avail_size(self) -> int: class Reducer: - def __init__(self, bucket_size_mb: int = 25): self.bucket_size_mb = bucket_size_mb self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {} @@ -101,7 +97,7 @@ def free(self) -> None: @functools.lru_cache() def _get_bucket_size(self, element_size: int) -> int: - if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. + if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. return 0 MB = 1024 * 1024 bucket_size = self.bucket_size_mb * MB / element_size diff --git a/colossalai/legacy/pipeline/__init__.py b/colossalai/legacy/pipeline/__init__.py index f36f54ac9307..9f1a5ec7fd1f 100644 --- a/colossalai/legacy/pipeline/__init__.py +++ b/colossalai/legacy/pipeline/__init__.py @@ -1,4 +1,4 @@ from .layer_spec import LayerSpec from .pipelinable import PipelinableContext, PipelinableModel -__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] +__all__ = ["PipelinableModel", "PipelinableContext", "LayerSpec"] diff --git a/colossalai/legacy/pipeline/layer_spec.py b/colossalai/legacy/pipeline/layer_spec.py index 3960debd7f72..825816e1c032 100644 --- a/colossalai/legacy/pipeline/layer_spec.py +++ b/colossalai/legacy/pipeline/layer_spec.py @@ -4,9 +4,7 @@ class LayerSpec: - """ - - """ + """ """ def __init__(self, typename, *module_args, **module_kwargs): self.typename = typename @@ -16,7 +14,7 @@ def __init__(self, typename, *module_args, **module_kwargs): self._param_count = 0 if not issubclass(typename, torch.nn.Module): - raise RuntimeError('LayerSpec only supports torch.nn.Module types.') + raise RuntimeError("LayerSpec only supports torch.nn.Module types.") def __repr__(self): return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs) diff --git a/colossalai/legacy/pipeline/middleware/__init__.py b/colossalai/legacy/pipeline/middleware/__init__.py index 481741bfee31..8a678b7b4c87 100644 --- a/colossalai/legacy/pipeline/middleware/__init__.py +++ b/colossalai/legacy/pipeline/middleware/__init__.py @@ -1,3 +1,3 @@ from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo -__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] +__all__ = ["Topo", "Partition", "PartitionOutputVal", "PartitionInputVal"] diff --git a/colossalai/legacy/pipeline/middleware/adaptor/__init__.py b/colossalai/legacy/pipeline/middleware/adaptor/__init__.py index 0b0d36d2ffe5..7f2b18670a76 100644 --- a/colossalai/legacy/pipeline/middleware/adaptor/__init__.py +++ b/colossalai/legacy/pipeline/middleware/adaptor/__init__.py @@ -1,3 +1,3 @@ from .fx import get_topology as get_fx_topology -__all__ = ['get_fx_topology'] +__all__ = ["get_fx_topology"] diff --git a/colossalai/legacy/pipeline/middleware/adaptor/fx.py b/colossalai/legacy/pipeline/middleware/adaptor/fx.py index 8cc40f120f15..34b21f8be1bb 100644 --- a/colossalai/legacy/pipeline/middleware/adaptor/fx.py +++ b/colossalai/legacy/pipeline/middleware/adaptor/fx.py @@ -10,7 +10,7 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False): elif is_output: partition_id = 1 else: - prefix = 'submod_' + prefix = "submod_" partition_id = int(partition_name.split(prefix)[-1]) + 2 return partition_id @@ -27,10 +27,10 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False): def find_input_in_partition(node, partitions, input_partitions=None): p_input_val = None - direct_def = not node.name.startswith('getitem') + direct_def = not node.name.startswith("getitem") # search in input if direct_def and input_partitions is not None: - partition_id = partition_name_to_id('', is_input=True) + partition_id = partition_name_to_id("", is_input=True) for i, input_node in enumerate(input_partitions): if input_node == node: p_input_val = PartitionInputVal(partition_id=partition_id, offset=i) @@ -57,7 +57,7 @@ def find_input_in_partition(node, partitions, input_partitions=None): def find_output_in_partition(node, partitions, output_partitions=None): p_output_val = PartitionOutputVal() for user in node.users: - direct_use = not user.name.startswith('getitem') + direct_use = not user.name.startswith("getitem") # user is mid partition for partition in partitions: # direct call @@ -82,7 +82,7 @@ def find_output_in_partition(node, partitions, output_partitions=None): output_node = output_partitions[0] if user.op == output_node.op: output_keys = {} - partition_id = partition_name_to_id('', is_output=True) + partition_id = partition_name_to_id("", is_output=True) torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n)) for i, arg in enumerate(output_keys): if arg == node: @@ -99,11 +99,11 @@ def get_topology(gm: GraphModule): partitions = [] output_partitions = [] for node in gm.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": input_partitions.append(node) - elif node.name.startswith('submod_'): + elif node.name.startswith("submod_"): partitions.append(node) - elif node.op == 'output': + elif node.op == "output": output_partitions.append(node) else: continue @@ -127,7 +127,7 @@ def get_topology(gm: GraphModule): # set output for submodule direct_use = True for user in partition.users: - if user.name.startswith('getitem'): + if user.name.startswith("getitem"): direct_use = False break if direct_use: @@ -146,7 +146,8 @@ def get_topology(gm: GraphModule): topo_output_partition = Partition() torch.fx.graph.map_arg( partition.args[0], - lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions))) + lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)), + ) topo.set_partitions(partition_id=1, partition=topo_output_partition) topo.set_output_partition_id(partition_id=1) diff --git a/colossalai/legacy/pipeline/middleware/topo.py b/colossalai/legacy/pipeline/middleware/topo.py index 3c21cce6dc0e..d0e3d2c3dedf 100644 --- a/colossalai/legacy/pipeline/middleware/topo.py +++ b/colossalai/legacy/pipeline/middleware/topo.py @@ -10,7 +10,7 @@ class ValPosition: offset: int def __str__(self) -> str: - res = f'[partition_id:{self.partition_id},offset:{self.offset}]' + res = f"[partition_id:{self.partition_id},offset:{self.offset}]" return res def __repr__(self) -> str: @@ -18,7 +18,6 @@ def __repr__(self) -> str: class PartitionInputVal(object): - def __init__(self, partition_id, offset) -> None: # every input from which partition_id and which offset val_pos = ValPosition(partition_id, offset) @@ -28,8 +27,8 @@ def get(self): return self._from_partition_and_offset def __str__(self) -> str: - res = '' - res += f'<-({self._from_partition_and_offset})' + res = "" + res += f"<-({self._from_partition_and_offset})" return res def __repr__(self) -> str: @@ -37,7 +36,6 @@ def __repr__(self) -> str: class PartitionOutputVal(object): - def __init__(self) -> None: # every output to which partition_id and which offset self._to_partition_and_offset: List[ValPosition] = [] @@ -50,11 +48,11 @@ def get(self): return self._to_partition_and_offset def __str__(self) -> str: - res = '' - res += '->(' + res = "" + res += "->(" for val_pos in self._to_partition_and_offset: - res += f'{val_pos},' - res += ')' + res += f"{val_pos}," + res += ")" return res def __repr__(self) -> str: @@ -62,7 +60,6 @@ def __repr__(self) -> str: class Partition(object): - def __init__(self) -> None: self._input_vals: List[PartitionInputVal] = [] self._output_vals: List[PartitionOutputVal] = [] @@ -110,16 +107,16 @@ def get_output_partition_ids(self): return res def __str__(self) -> str: - res = '' - res += f' input:\n' - res += f' length:{len(self._input_vals)}\n' + res = "" + res += f" input:\n" + res += f" length:{len(self._input_vals)}\n" for i, input_val in enumerate(self._input_vals): - res += f' offset={i}:{input_val}\n' + res += f" offset={i}:{input_val}\n" - res += f' output:\n' - res += f' length:{len(self._output_vals)}\n' + res += f" output:\n" + res += f" length:{len(self._output_vals)}\n" for i, output_val in enumerate(self._output_vals): - res += f' offset={i}:{output_val}\n' + res += f" offset={i}:{output_val}\n" return res @@ -140,7 +137,6 @@ def __repr__(self) -> str: # _input_partition_id: the key represents input_partition # _output_partition_id: the key represents output_partition class Topo(object): - def __init__(self, input_partition_id=None, output_partition_id=None) -> None: self._partitions: Dict[int, Partition] = {} self._input_partition_id = input_partition_id @@ -162,7 +158,7 @@ def set_partitions(self, partition_id: int, partition: Partition): self._partitions[partition_id] = partition def get_mid_partitions(self): - res = {} #{partition_id: Partition} + res = {} # {partition_id: Partition} for partition_id, partition in self._partitions.items(): if self._input_partition_id == partition_id or self._output_partition_id == partition_id: continue @@ -186,27 +182,27 @@ def get_partition_by_id(self, partition_id): return self._partitions[partition_id] def __str__(self) -> str: - res = '' + res = "" if len(self._partitions) == 0: - return 'Empty Topo Graph.' + return "Empty Topo Graph." input_part = self.get_input_partition() if input_part is not None: - res += '{\n' - res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}' - res += '}\n' + res += "{\n" + res += f"InputPartition:\n partition_id={self._input_partition_id}\n{input_part}" + res += "}\n" mid_parts = self.get_mid_partitions() for i, (partition_id, part) in enumerate(mid_parts.items()): - res += '{\n' - res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}' - res += '}\n' + res += "{\n" + res += f"SubPartition_{i}:\n partition_id={partition_id}\n {part}" + res += "}\n" output_part = self.get_output_partition() if output_part is not None: - res += '{\n' - res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}' - res += '}\n' + res += "{\n" + res += f"OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}" + res += "}\n" return res diff --git a/colossalai/legacy/pipeline/pipelinable.py b/colossalai/legacy/pipeline/pipelinable.py index e74cad0ad1b0..82ccdb554527 100644 --- a/colossalai/legacy/pipeline/pipelinable.py +++ b/colossalai/legacy/pipeline/pipelinable.py @@ -132,8 +132,8 @@ def to_layer_list(self, exec_seq=None): for child in self._root_children: layer_spec = self._layer_spec_dict[id(child)] if layer_spec.typename in ( - torch.nn.modules.container.ModuleList, - torch.nn.modules.container.Sequential, + torch.nn.modules.container.ModuleList, + torch.nn.modules.container.Sequential, ): for child_in_container in layer_spec.children: self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)]) @@ -198,8 +198,9 @@ def partition(self, num_chunks, pipeline_size, rank): param_counts.append(layer_spec.count_params()) parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank] elif self._policy == "customized": - assert (self._exec_seq - is not None), f"An explicit exec_seq must be defined by user in customized policy mode." + assert ( + self._exec_seq is not None + ), f"An explicit exec_seq must be defined by user in customized policy mode." self.customized_parts = customized_partition(self._exec_seq) assert len(self.customized_parts) == gpc.get_world_size( ParallelMode.PIPELINE @@ -226,14 +227,14 @@ def partition(self, num_chunks, pipeline_size, rank): elif (layer, "behind") in self._func_dict: behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")] module_list_in_partition = torch.nn.ModuleList(module_list_in_partition) - pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition, - behind_func_dict_in_partition) + pipeline_model = PipelinableModel( + module_list_in_partition, front_func_dict_in_partition, behind_func_dict_in_partition + ) return pipeline_model class PipelinableModel(torch.nn.Module): - def __init__(self, module_list, front_func_dict, behind_func_dict): super().__init__() self._module_list = module_list diff --git a/colossalai/legacy/pipeline/pipeline_process_group.py b/colossalai/legacy/pipeline/pipeline_process_group.py index 1168158defaf..2d0d5be87cac 100644 --- a/colossalai/legacy/pipeline/pipeline_process_group.py +++ b/colossalai/legacy/pipeline/pipeline_process_group.py @@ -1,6 +1,5 @@ -import os import threading -from typing import Dict, List, Tuple +from typing import List import torch.distributed as dist from torch.distributed import rpc @@ -14,14 +13,15 @@ class PipelineProcessGroup: def __init__(self) -> None: self.is_initialize = False - def set_global_info(self, - rank: int, - world_size: int, - dp_degree: int = 1, - tp_degree: int = 1, - num_worker_threads: int = 1, - device: str = "cuda") -> None: - + def set_global_info( + self, + rank: int, + world_size: int, + dp_degree: int = 1, + tp_degree: int = 1, + num_worker_threads: int = 1, + device: str = "cuda", + ) -> None: device_mesh_size = dp_degree * tp_degree assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!" self._num_worker_threads = num_worker_threads @@ -60,8 +60,8 @@ def _initialize_process_group(self): device = self.device world_size = self.get_world_size() rank = self.get_global_rank() - backend = 'nccl' if device == 'cuda' else 'gloo' - dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group') + backend = "nccl" if device == "cuda" else "gloo" + dist.init_process_group(backend, world_size=world_size, rank=rank, group_name="main_group") def _initialize_pp_process_group(self) -> None: rank = self.get_global_rank() @@ -71,9 +71,9 @@ def _initialize_pp_process_group(self) -> None: options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads) for pp_rank in self._pp_ranks: - options.set_device_map(f'work{pp_rank}', {rank: pp_rank}) + options.set_device_map(f"work{pp_rank}", {rank: pp_rank}) - rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options) + rpc.init_rpc(name=f"work{rank}", rank=rank, world_size=world_size, rpc_backend_options=options) def _initialize_tp_dp_process_group(self) -> None: rank = self.get_global_rank() @@ -147,10 +147,10 @@ def get_tp_global_ranks(self): def get_chimera_all_reduce_group(self, pp_rank: int): with self.chimera_lock: - if not hasattr(self, 'chimera_groups'): + if not hasattr(self, "chimera_groups"): world_size = self.get_world_size() stage_num = self.get_stage_num() - assert world_size % 2 == 0, 'world_size must be even in chimera!' + assert world_size % 2 == 0, "world_size must be even in chimera!" self.chimera_groups = {} for rank in range(world_size // 2): pair = [rank, world_size - 1 - rank] diff --git a/colossalai/legacy/pipeline/rpc/__init__.py b/colossalai/legacy/pipeline/rpc/__init__.py index 15b65a4138a8..791b9d530673 100644 --- a/colossalai/legacy/pipeline/rpc/__init__.py +++ b/colossalai/legacy/pipeline/rpc/__init__.py @@ -1,4 +1,4 @@ from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine from .utils import pytree_map -__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map'] +__all__ = ["FillDrainPipelineEngine", "OneFOneBPipelineEngine", "ChimeraPipelineEngine", "pytree_map"] diff --git a/colossalai/legacy/pipeline/rpc/_pipeline_base.py b/colossalai/legacy/pipeline/rpc/_pipeline_base.py index 88ddb9e98eb2..d203e1a11180 100644 --- a/colossalai/legacy/pipeline/rpc/_pipeline_base.py +++ b/colossalai/legacy/pipeline/rpc/_pipeline_base.py @@ -12,17 +12,9 @@ from torch._C._distributed_rpc import PyRRef from torch.futures import Future -from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.legacy.pipeline.middleware import Partition, Topo from colossalai.legacy.pipeline.pipeline_process_group import ppg -from colossalai.legacy.pipeline.rpc.utils import ( - get_batch_lengths, - pyobj_map, - pytree_filter, - pytree_map, - split_batch, - tensor_shape_list, - type_detail, -) +from colossalai.legacy.pipeline.rpc.utils import get_batch_lengths, pyobj_map, pytree_filter, pytree_map, split_batch class Phase(Enum): @@ -33,7 +25,7 @@ class Phase(Enum): class UniqueKey: - __slots__ = ('microbatch_id', 'phase') + __slots__ = ("microbatch_id", "phase") microbatch_id: int phase: Phase @@ -48,12 +40,22 @@ def __hash__(self) -> int: return tuple.__hash__((self.microbatch_id, self.phase)) def __repr__(self) -> str: - return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})' + return f"Key(microbatch_id={self.microbatch_id}, phase={self.phase})" class WorkItem: - __slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id', - 'num_microbatches', 'forward_only') + __slots__ = ( + "stage_id", + "phase", + "args", + "kwargs", + "output", + "refcount", + "microbatch_id", + "batch_id", + "num_microbatches", + "forward_only", + ) stage_id: int phase: Phase @@ -66,50 +68,45 @@ class WorkItem: num_microbatches: int forward_only: bool - def __init__(self, - stage_id, - phase, - args, - kwargs, - output, - microbatch_id, - batch_id, - num_microbatches, - forward_only, - refcount=0) -> None: + def __init__( + self, stage_id, phase, args, kwargs, output, microbatch_id, batch_id, num_microbatches, forward_only, refcount=0 + ) -> None: for attr_name in self.__slots__: setattr(self, attr_name, locals()[attr_name]) class BackwardCache: - __slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs') + __slots__ = ("checkpoint", "stage_input_args", "stage_input_kwargs", "stage_outputs") checkpoint: bool stage_input_args: Tuple[Any] stage_input_kwargs: Dict[Any, Any] stage_outputs: Tuple[Any] - def __init__(self, - stage_input_args: Tuple[Any], - stage_input_kwargs: Dict[Any, Any] = None, - stage_outputs: Tuple[Any] = None, - checkpoint: bool = False) -> None: + def __init__( + self, + stage_input_args: Tuple[Any], + stage_input_kwargs: Dict[Any, Any] = None, + stage_outputs: Tuple[Any] = None, + checkpoint: bool = False, + ) -> None: for arg_name in self.__slots__: setattr(self, arg_name, locals()[arg_name]) class WorkerBase(ABC): - - def __init__(self, - partition_fn: Callable, - partition_args: tuple, - pp_rank: int, - actual_stage_num: int, - num_microbatches: int, - device: str, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: + def __init__( + self, + partition_fn: Callable, + partition_args: tuple, + pp_rank: int, + actual_stage_num: int, + num_microbatches: int, + device: str, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: super().__init__() self.pp_rank = pp_rank @@ -150,11 +147,11 @@ def __init__(self, self._initialize_context_container() # main loop - self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{pp_rank}', daemon=True) + self.main_loop_thread = threading.Thread(target=self._work_loop, name=f"rank_{pp_rank}", daemon=True) self.main_loop_thread.start() def _get_future_by_device(self): - return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device]) + return torch.futures.Future(devices=None if self.device in (None, "cpu") else [self.device]) def _initialize_outstanding_range(self): outstanding_range = None @@ -199,12 +196,13 @@ def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None): # lifecycle management for DAG scheduler if output_work_item.phase == Phase.FORWARD: lifecycle = len(self.get_consumer_stage_ids()) - if self.is_model_output(): # an extra reference for scheduler collecting results + if self.is_model_output(): # an extra reference for scheduler collecting results lifecycle += 1 elif output_work_item.phase == Phase.BACKWARD: lifecycle = len(self.get_producer_stage_ids()) if self.is_model_input() and self._is_last_step( - output_work_item): # an extra reference for ensure_backward + output_work_item + ): # an extra reference for ensure_backward lifecycle += 1 else: lifecycle = 0 @@ -234,9 +232,9 @@ def sync_global_worker_rrefs(self, pp_rank_to_worker_rref: Dict[int, PyRRef]) -> # offset supports get partial output to reduce comm costs. def get_output_by_key(self, key: UniqueKey, ref_use=False, rank=None, offsets=None) -> Any: output = self._get_output_all(key, ref_use, rank) - if offsets is None: # get all for non iterable output + if offsets is None: # get all for non iterable output return output - else: # get part for iterable output + else: # get part for iterable output output = [output[i] for i in offsets] return output @@ -252,12 +250,12 @@ def get_parameter_gradients(self) -> List[torch.Tensor]: def get_partition(self): with self.partition_condition_lock: - self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition")) return self.module_partition def get_partition_state_dict(self): with self.partition_condition_lock: - self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) + self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition")) return self.module_partition.state_dict() def _make_args_kwargs(self, microbatch, merge=False): @@ -293,8 +291,17 @@ def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bo # make args and kwargs args, kwargs = self._make_args_kwargs(microbatch) - work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None, - self.num_microbatches, forward_only) + work_item = WorkItem( + self.pp_rank, + Phase.FORWARD, + args, + kwargs, + output, + microbatch_id, + None, + self.num_microbatches, + forward_only, + ) with self.work_list_condition_lock: self.work_list[key] = work_item self.work_list_condition_lock.notify_all() @@ -314,15 +321,25 @@ def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bo for off in self_input_offsets: self_arg_lst.append(arg_lst[off]) - work_item = WorkItem(self.pp_rank, Phase.FORWARD, self_arg_lst, {}, output, microbatch_id, None, - self.num_microbatches, forward_only) + work_item = WorkItem( + self.pp_rank, + Phase.FORWARD, + self_arg_lst, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + forward_only, + ) with self.work_list_condition_lock: self.work_list[key] = work_item self.work_list_condition_lock.notify_all() # put input tensor which other nodes need into output_list as Phase.INPUT - work_item_remote = WorkItem(self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, - self.num_microbatches, forward_only) + work_item_remote = WorkItem( + self.pp_rank, Phase.INPUT, [], {}, arg_lst, microbatch_id, None, self.num_microbatches, forward_only + ) with self.output_list_condition_lock: self.output_list[recv_input_key] = work_item_remote @@ -343,8 +360,17 @@ def _begin_backward(self, microbatch_id: int): output = self._get_future_by_device() grad_wrt_loss = None - work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, - self.num_microbatches, False) + work_item = WorkItem( + self.pp_rank, + Phase.BACKWARD, + grad_wrt_loss, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + False, + ) self.work_list[key] = work_item self.work_list_condition_lock.notify_all() @@ -367,7 +393,7 @@ def _subscribe_producer(self, microbatch_id: int, forward_only: bool): producer_stage_ids = self.get_producer_stage_ids() producer_num = len(producer_stage_ids) if self.need_model_input(): - producer_num += 1 # for input partition + producer_num += 1 # for input partition subscribe_forward_futures: List[Future] = [None] * producer_num # TODO(jiangziyue) get single value instead of the whole output @@ -376,9 +402,9 @@ def _subscribe_producer(self, microbatch_id: int, forward_only: bool): producer_output_key = UniqueKey(microbatch_id, Phase.INPUT) producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] offsets = self._get_input_offsets_by_index(target_index=0) - subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key, - rank=self.pp_rank, - offsets=offsets) + subscribe_forward_futures[0] = producer_worker_rref.rpc_async().get_output_by_key( + producer_output_key, rank=self.pp_rank, offsets=offsets + ) for i in range(0, producer_num - 1): producer_stage_id = producer_stage_ids[i] @@ -386,11 +412,12 @@ def _subscribe_producer(self, microbatch_id: int, forward_only: bool): producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] target_index = i + 1 offsets = self._get_input_offsets_by_index(target_index=target_index) - if offsets is not None and len(offsets) == 0: # no need to do rpc + if offsets is not None and len(offsets) == 0: # no need to do rpc subscribe_forward_futures[target_index] = [] else: subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( - producer_output_key, rank=self.pp_rank, offsets=offsets) + producer_output_key, rank=self.pp_rank, offsets=offsets + ) else: for i in range(producer_num): @@ -399,14 +426,24 @@ def _subscribe_producer(self, microbatch_id: int, forward_only: bool): producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] target_index = i offsets = self._get_input_offsets_by_index(target_index=target_index) - if offsets is not None and len(offsets) == 0: # no need to do rpc + if offsets is not None and len(offsets) == 0: # no need to do rpc subscribe_forward_futures[target_index] = [] else: subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key( - producer_output_key, rank=self.pp_rank, offsets=offsets) - - work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output, - microbatch_id, None, self.num_microbatches, forward_only) + producer_output_key, rank=self.pp_rank, offsets=offsets + ) + + work_item_from_producer = WorkItem( + stage_id, + Phase.FORWARD, + subscribe_forward_futures, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + forward_only, + ) return work_item_from_producer @@ -441,15 +478,25 @@ def _subscribe_consumer(self, microbatch_id: int): consumer_worker_rref = self.pp_rank_to_worker_rref[consumer_stage_id] target_index = i offsets = self._get_output_offsets_by_index(target_index=target_index) - if offsets is not None and len(offsets) == 0: # no need to do rpc + if offsets is not None and len(offsets) == 0: # no need to do rpc subscribe_backward_futures[target_index] = [] else: subscribe_backward_futures[target_index] = consumer_worker_rref.rpc_async().get_output_by_key( - consumer_output_key, rank=self.pp_rank, offsets=offsets) + consumer_output_key, rank=self.pp_rank, offsets=offsets + ) # flatten args - work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, subscribe_backward_futures, {}, output, - microbatch_id, None, self.num_microbatches, False) + work_item_from_consumer = WorkItem( + stage_id, + Phase.BACKWARD, + subscribe_backward_futures, + {}, + output, + microbatch_id, + None, + self.num_microbatches, + False, + ) return work_item_from_consumer @@ -524,8 +571,8 @@ def partition_id_to_pp_rank(self, partition_id: int, topo: Topo): def get_topo(self): with self.partition_condition_lock: - self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) - if hasattr(self.module_partition, '_topo'): + self.partition_condition_lock.wait_for(lambda: hasattr(self, "module_partition")) + if hasattr(self.module_partition, "_topo"): return self.module_partition._topo else: return None @@ -564,12 +611,12 @@ def _get_input_offsets_by_index(self, target_index): if stage_id == src_stage_id: src_index += i break - else: # data from input partition + else: # data from input partition src_index = 0 # when output_len = 1, not iterable if target_index == src_index: if output_len == 1: - res = None # offset = None to get all outputs + res = None # offset = None to get all outputs return res else: res.append(src_offset) @@ -584,7 +631,6 @@ def _get_output_offsets_by_index(self, target_index): consumer_stage_ids = self.get_consumer_stage_ids() for val_list in output_vals: # An output may be passed to many down stages. - target = None for val_pos in val_list.get(): dst_partition_id = val_pos.partition_id dst_offset = val_pos.offset @@ -597,7 +643,7 @@ def _get_output_offsets_by_index(self, target_index): break if target_index == dst_index: if input_len == 1: - res = None # offset = None to get all outputs + res = None # offset = None to get all outputs return res else: res.append(dst_offset) @@ -623,7 +669,7 @@ def _get_real_args_kwargs_fwd(self, args_or_kwargs): flatten_args = [] if self.is_first_stage(): pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) - else: # get by offset + else: # get by offset topo: Topo = self.get_topo() self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) self_partition: Partition = topo.get_partition_by_id(self_partition_id) @@ -652,7 +698,7 @@ def _get_real_args_kwargs_fwd(self, args_or_kwargs): if stage_id == src_stage_id: src_index += i break - else: # data from input partition + else: # data from input partition src_index = 0 # when output_len = 1, not iterable if output_len == 1: @@ -679,7 +725,7 @@ def _get_real_args_kwargs_bwd(self, args_or_kwargs): else: for i, arg in enumerate(args_or_kwargs): args_or_kwargs[i] = arg.wait() - if args_or_kwargs is not None: # get by offset + if args_or_kwargs is not None: # get by offset flatten_args = [] topo: Topo = self.get_topo() self_partition_id = self.pp_rank_to_partition_id(self.pp_rank, topo) @@ -719,7 +765,7 @@ def _get_real_args_kwargs_bwd(self, args_or_kwargs): @abstractmethod def _get_work_item_key(self) -> UniqueKey: """ - this method control the order of the microbatch to consume + this method control the order of the microbatch to consume """ def is_first_stage(self): @@ -761,7 +807,7 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): kwargs = work_item.kwargs microbatch_id = work_item.microbatch_id forward_only = work_item.forward_only - data_process_func = getattr(self, 'data_process_func', self._default_data_process_func) + data_process_func = getattr(self, "data_process_func", self._default_data_process_func) consume_result = None is_first_stage = self.is_first_stage() @@ -787,10 +833,12 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): else: args_kwargs = self._get_real_args_kwargs_fwd(args) - args_kwargs = pyobj_map(args_kwargs, fn=lambda x: x.to(self.device).detach(), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU - args_kwargs = pyobj_map(args_kwargs, fn=lambda x: self.device, - process_types=torch.device) # change devices from last stage to current device + args_kwargs = pyobj_map( + args_kwargs, fn=lambda x: x.to(self.device).detach(), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in GPU + args_kwargs = pyobj_map( + args_kwargs, fn=lambda x: self.device, process_types=torch.device + ) # change devices from last stage to current device args, kwargs = data_process_func(args_kwargs) @@ -851,16 +899,16 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): use_checkpoint = False if not forward_only: - self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args, - stage_input_kwargs, - stage_outputs, - checkpoint=use_checkpoint) - consume_result = pyobj_map(consume_result, fn=lambda x: x.to('cpu'), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in + self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache( + stage_input_args, stage_input_kwargs, stage_outputs, checkpoint=use_checkpoint + ) + consume_result = pyobj_map( + consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in # if not forward_only, do the backward if not forward_only: - if is_last_stage: # if it is the last stage, trigger backward automatic + if is_last_stage: # if it is the last stage, trigger backward automatic self._begin_backward(microbatch_id) elif phase == Phase.BACKWARD: @@ -872,7 +920,9 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): self.backward_times += 1 self.outstanding -= 1 - assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache" + assert ( + microbatch_id in self.microbatch_id_to_backward_cache + ), f"microbatch_id {microbatch_id} not in backward cache" backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id) stage_outputs = backward_cache.stage_outputs @@ -906,8 +956,9 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): filtered_grads.append(grad) stage_outputs = filtered_outputs - grad_tensors = pyobj_map(filtered_grads, fn=lambda x: x.to(self.device), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + grad_tensors = pyobj_map( + filtered_grads, fn=lambda x: x.to(self.device), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in GPU autograd.backward(stage_outputs, grad_tensors=grad_tensors) # collect grad of input tensor @@ -920,8 +971,8 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): else: consume_result.append(None) consume_result = pyobj_map( - consume_result, fn=lambda x: x.to('cpu'), - process_types=torch.Tensor) # torch rpc doesn't support args or rets in GPU + consume_result, fn=lambda x: x.to("cpu"), process_types=torch.Tensor + ) # torch rpc doesn't support args or rets in GPU else: raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") @@ -929,7 +980,7 @@ def _consume_work_item_by_phase(self, work_item: WorkItem): return consume_result def _get_store_len(self): - return f'work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}' + return f"work_list:{len(self.work_list)} output_list:{len(self.output_list)} backward_cache:{len(self.microbatch_id_to_backward_cache)} label_cache:{len(self.microbatch_id_to_labels)}" def _get_parameter_grad_sum(self): grad_sum = 0 @@ -1014,19 +1065,20 @@ def step(self): class PipelineEngineBase(ABC, nn.Module): - - def __init__(self, - worker_type, - partition_fn: Callable, - stage_num, - num_microbatches, - device: str, - use_1F1B=False, - chunk: int = 1, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: + def __init__( + self, + worker_type, + partition_fn: Callable, + stage_num, + num_microbatches, + device: str, + use_1F1B=False, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: super().__init__() self.worker_type = worker_type self.partition_fn: Callable = partition_fn @@ -1056,12 +1108,12 @@ def _check_argument(self) -> None: data_process_func = self.data_process_func if data_process_func is not None: assert callable(data_process_func), "data_process_func must be a function" - assert '' not in data_process_func.__repr__(), "data_process_func must be a global function" - assert '' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression" + assert "" not in data_process_func.__repr__(), "data_process_func must be a global function" + assert "" not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression" sig = inspect.signature(data_process_func) - assert len( - sig.parameters - ) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead" + assert ( + len(sig.parameters) == 2 + ), f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead" def _get_actual_stage_num(self) -> int: return self.stage_num if self.chunk == 1 else self.virtual_stage_num @@ -1104,19 +1156,33 @@ def _init_worker(self) -> None: partition_id = self.pp_rank_to_module_partition_id[pp_rank] partition_args = (partition_id, chunk, actual_stage_num) rpc_worker_id = self.pp_rank_to_rpc_worker_id[pp_rank] - if device[:4] == 'cuda': - device = f'cuda:{rpc_worker_id}' - self.pp_rank_to_worker_rref[pp_rank] = rpc.remote(rpc_worker_id, - worker_type, - args=(partition_fn, partition_args, pp_rank, - actual_stage_num, num_microbatches, device, - criterion, metric, checkpoint, data_process_func)) + if device[:4] == "cuda": + device = f"cuda:{rpc_worker_id}" + self.pp_rank_to_worker_rref[pp_rank] = rpc.remote( + rpc_worker_id, + worker_type, + args=( + partition_fn, + partition_args, + pp_rank, + actual_stage_num, + num_microbatches, + device, + criterion, + metric, + checkpoint, + data_process_func, + ), + ) # let each worker know global worker rref (include itself) sync_futs = [] for pp_rank in self.pp_rank_to_worker_rref: - fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs( - self.pp_rank_to_worker_rref) + fut = ( + self.pp_rank_to_worker_rref[pp_rank] + .rpc_async(timeout=0) + .sync_global_worker_rrefs(self.pp_rank_to_worker_rref) + ) sync_futs.append(fut) for fut in sync_futs: @@ -1157,8 +1223,9 @@ def get_input_pp_ranks(self) -> List[int]: def get_output_pp_ranks(self) -> List[int]: return [self._get_actual_stage_num() - 1] - def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], - output_pp_ranks: List[int], ret_future): + def _consume_constraint( + self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future + ): actual_stage_num = self._get_actual_stage_num() use_1F1B = self.use_1F1B if microbatch_id >= actual_stage_num: @@ -1206,7 +1273,8 @@ def _ensure_backward(self, forward_only: bool, input_pp_ranks: List[int]): worker_rref = self.pp_rank_to_worker_rref[pp_rank] key = UniqueKey(self.num_microbatches - 1, Phase.BACKWARD) fut = worker_rref.rpc_async().get_output_by_key( - key, offsets=[]) # only ensure the res exists, no need for real data. + key, offsets=[] + ) # only ensure the res exists, no need for real data. backward_result.append(fut) for fut in backward_result: @@ -1244,11 +1312,14 @@ def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, for if labels is not None and not forward_only: assert hasattr( - self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward" + self, "optimizer_class" + ), "call `initialize_optimizer` to initialize optimizer before forward_backward" num_microbatches = self.num_microbatches - assert batch_length >= num_microbatches, "num_microbatches is greater than the size of a batch, which is illegal" + assert ( + batch_length >= num_microbatches + ), "num_microbatches is greater than the size of a batch, which is illegal" microbatch_size = math.ceil(batch_length / num_microbatches) device = self.device @@ -1285,10 +1356,10 @@ def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, for # collect forward result forward_result = self._collect_forward_result(output_pp_ranks, ret_future) - if not forward_only and hasattr(self, 'optimizer_class'): + if not forward_only and hasattr(self, "optimizer_class"): self.step() - self._reset_worker() # reset worker attributes for next batch + self._reset_worker() # reset worker attributes for next batch return forward_result def initialize_optimizer(self, optimizer_class: type, **kwargs): diff --git a/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py b/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py index f53a4835edf2..56da2a954225 100644 --- a/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py +++ b/colossalai/legacy/pipeline/rpc/_pipeline_schedule.py @@ -2,7 +2,6 @@ from typing import Callable, Dict, List import torch -import torch.distributed as dist from torch._C._distributed_rpc import PyRRef from torch.futures import Future @@ -15,7 +14,6 @@ class FillDrainWorker(WorkerBase): - def _get_work_item_key(self) -> UniqueKey: # execute backward first (if backward phase in work_list) num_microbatches = self.num_microbatches @@ -33,29 +31,40 @@ def _get_work_item_key(self) -> UniqueKey: class FillDrainPipelineEngine(PipelineEngineBase): - - def __init__(self, - partition_fn: Callable, - stage_num: int, - num_microbatches: int, - device: str, - chunk: int = 1, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: - + def __init__( + self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: if chunk > 1: - assert num_microbatches % stage_num == 0, \ - "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + assert ( + num_microbatches % stage_num == 0 + ), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" use_1F1B = False - super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint, data_process_func) + super().__init__( + FillDrainWorker, + partition_fn, + stage_num, + num_microbatches, + device, + use_1F1B, + chunk, + criterion, + metric, + checkpoint, + data_process_func, + ) class OneFOneBWorker(WorkerBase): - def _get_work_item_key(self) -> UniqueKey: # execute backward first (if backward phase in work_list) pp_rank = self.pp_rank @@ -77,8 +86,7 @@ def _get_work_item_key(self) -> UniqueKey: # change outstanding_range at: # 1. forward times reach actual_stage_num, this is the end of continuous forward # 2. forward times reach num_microbatches, this is the end of 1F1B mode - if not is_last_stage and \ - target_key.phase == Phase.FORWARD: + if not is_last_stage and target_key.phase == Phase.FORWARD: if target_key.microbatch_id == actual_stage_num - 1 and num_microbatches > 2: # Why need num_microbatches > 2 ? Because there is no steady stage when num_microbatches <= 2 outstanding_min = actual_stage_num - pp_rank - 1 @@ -91,30 +99,41 @@ def _get_work_item_key(self) -> UniqueKey: class OneFOneBPipelineEngine(PipelineEngineBase): - - def __init__(self, - partition_fn: Callable, - stage_num: int, - num_microbatches: int, - device: str, - chunk: int = 1, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: - + def __init__( + self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + chunk: int = 1, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: if chunk > 1: - assert num_microbatches % stage_num == 0, \ - "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" + assert ( + num_microbatches % stage_num == 0 + ), "if you use interleaving strategy, make sure 'num_microbatches' is a multiple of stage_num!" # assert num_microbatches > stage_num * chunk, "num_microbatches must be greater than stage_num * chunk" use_1F1B = True - super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint, data_process_func) + super().__init__( + OneFOneBWorker, + partition_fn, + stage_num, + num_microbatches, + device, + use_1F1B, + chunk, + criterion, + metric, + checkpoint, + data_process_func, + ) class ChimeraWorker(WorkerBase): - def _get_producer_consumer(self) -> None: rank = self.pp_rank min_pp_rank = (rank // self.actual_stage_num) * self.actual_stage_num @@ -143,11 +162,12 @@ def _get_work_item_key(self) -> UniqueKey: forward_block_size = 1 if self.num_microbatches < stage_num else self.num_microbatches // stage_num forward_block_num = self.forward_times // forward_block_size - if self.forward_times >= real_microbatch_num or \ - ((pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times): + if self.forward_times >= real_microbatch_num or ( + (pp_rank + 1) % stage_num == 0 and forward_block_num > self.backward_times + ): target_phase = Phase.BACKWARD target_microbatch_id = self.backward_times - else: # others + else: # others target_phase = Phase.FORWARD target_microbatch_id = self.forward_times @@ -168,7 +188,7 @@ def _initialize_partition(self): # from corresponding up stage pp_rank = self.pp_rank stage_num = self.actual_stage_num - device = self.device + self.device if pp_rank < stage_num: super()._initialize_partition() else: @@ -242,27 +262,38 @@ def _hook_before_step(self): class ChimeraPipelineEngine(PipelineEngineBase): - - def __init__(self, - partition_fn: Callable, - stage_num: int, - num_microbatches: int, - device: str, - criterion: Callable = None, - metric: Callable = None, - checkpoint: bool = False, - data_process_func: Callable = None) -> None: - - assert num_microbatches % stage_num == 0, \ - "In Chimera, num_microbatches must be the multiply of stage_num!" + def __init__( + self, + partition_fn: Callable, + stage_num: int, + num_microbatches: int, + device: str, + criterion: Callable = None, + metric: Callable = None, + checkpoint: bool = False, + data_process_func: Callable = None, + ) -> None: + assert num_microbatches % stage_num == 0, "In Chimera, num_microbatches must be the multiply of stage_num!" use_1F1B = False chunk = 1 - super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint, data_process_func) - - def _consume_constraint(self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], - output_pp_ranks: List[int], ret_future): + super().__init__( + ChimeraWorker, + partition_fn, + stage_num, + num_microbatches, + device, + use_1F1B, + chunk, + criterion, + metric, + checkpoint, + data_process_func, + ) + + def _consume_constraint( + self, microbatch_id: int, forward_only: bool, input_pp_ranks: List[int], output_pp_ranks: List[int], ret_future + ): pass def _create_pp_rank_to_rpc_worker_id(self) -> None: diff --git a/colossalai/legacy/pipeline/rpc/utils.py b/colossalai/legacy/pipeline/rpc/utils.py index d1033fbde920..808de301a2a0 100644 --- a/colossalai/legacy/pipeline/rpc/utils.py +++ b/colossalai/legacy/pipeline/rpc/utils.py @@ -1,7 +1,7 @@ import argparse import os import warnings -from typing import Any, Callable, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Tuple, Type, Union import torch import torch.distributed.rpc as rpc @@ -61,7 +61,7 @@ def get_batch_lengths(batch): def split_batch(batch: Any, start, stop, device: str): - if device == 'cuda': + if device == "cuda": fn = lambda x: x[start:stop].cuda() else: fn = lambda x: x[start:stop] @@ -102,8 +102,8 @@ def get_real_args_kwargs(args_or_kwargs): def run_worker(rank, args, master_func): - os.environ['MASTER_ADDR'] = args.master_addr - os.environ['MASTER_PORT'] = args.master_port + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port device = args.device world_size = args.world_size @@ -112,15 +112,17 @@ def run_worker(rank, args, master_func): num_worker_threads = args.num_worker_threads host = args.master_addr port = args.master_port - backend = 'nccl' if device == 'cuda' else 'gloo' + backend = "nccl" if device == "cuda" else "gloo" launch(dict(), rank, world_size, host, int(port), backend, verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=dp_degree, - tp_degree=tp_degree, - num_worker_threads=num_worker_threads, - device=device) + ppg.set_global_info( + rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device, + ) ppg.args = args # in rpc mode, only rank 0 is needed to be coded if rank == 0: @@ -139,17 +141,17 @@ def rpc_run(args, master_func): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--epoch', type=int, default=1) - parser.add_argument('--world_size', type=int, default=2) - parser.add_argument('--batch_size', type=int, default=16) - parser.add_argument('--dp_degree', type=int, default=1) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--num_microbatches', type=int, default=2) - parser.add_argument('--chunk', type=int, default=1) - parser.add_argument('--use_checkpoint', action='store_true') - parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29020') - parser.add_argument('--num_worker_threads', type=int, default=128) + parser.add_argument("--epoch", type=int, default=1) + parser.add_argument("--world_size", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--num_microbatches", type=int, default=2) + parser.add_argument("--chunk", type=int, default=1) + parser.add_argument("--use_checkpoint", action="store_true") + parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "RMSprop"], default="SGD") + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29020") + parser.add_argument("--num_worker_threads", type=int, default=128) return parser.parse_args() diff --git a/colossalai/legacy/pipeline/utils.py b/colossalai/legacy/pipeline/utils.py index be8428692756..182af677c047 100644 --- a/colossalai/legacy/pipeline/utils.py +++ b/colossalai/legacy/pipeline/utils.py @@ -38,8 +38,7 @@ def _binary_partition(weights: List, start: int, end: int): def _heap_addition(weights: List, intervals: int, add_cnt: int): - """ - """ + """ """ def _heap_push(heap, st, ed): value = weights[ed - 1] @@ -113,8 +112,9 @@ def _binary_search(weights, num): def partition_uniform(num_items, pipeline_parallel_size, num_chunks): - assert num_items % num_chunks == 0, \ - "Layer length should be divided by the number of chunks, otherwise parameter method is recommended" + assert ( + num_items % num_chunks == 0 + ), "Layer length should be divided by the number of chunks, otherwise parameter method is recommended" logger = get_dist_logger() parts = [[] for _ in range(pipeline_parallel_size)] @@ -162,7 +162,7 @@ def build_kwargs_for_module(function, input_tensor, kw_dict): elif isinstance(input_tensor, torch.Tensor): kwargs_offset = 1 elif isinstance(input_tensor, (tuple, OrderedDict)): - #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' + # assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' # Huggingface will take their own structures based on OrderedDict as the output # between layers so we've to close this check. kwargs_offset = len(input_tensor) @@ -204,21 +204,21 @@ def foo(attention_mask=None): kwargs[k] = rst return input_tensor if isinstance(input_tensor, tuple): - assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.' + assert len(input_tensor) > 0, f"input_tensor should not be empty, when kw_dict is None." sig = inspect.signature(func) func_args_num = len(sig.parameters) assert func_args_num <= len( - input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.' + input_tensor + ), f"func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}." if func_args_num < len(input_tensor): return func(*input_tensor[:func_args_num]) else: return func(*input_tensor) - assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.' + assert isinstance(input_tensor, torch.Tensor), "input_tensor should be a type of torch.Tensor or tuple." return func(input_tensor) def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs): - assert func_key in func_dict, f"{func_key} is not in the function_dict." funcs_to_exec = func_dict[func_key] if isinstance(funcs_to_exec, list): @@ -243,7 +243,7 @@ def call_module(module, args=None, kwargs=None): forward_func = module.forward sig = inspect.signature(forward_func) param_nums = len(sig.parameters) - feed_nums = len(args) + len(kwargs) + len(args) + len(kwargs) args_needed_nums = param_nums - len(kwargs) args_needed = args[:args_needed_nums] if isinstance(module, CheckpointModule): @@ -256,17 +256,17 @@ def call_module(module, args=None, kwargs=None): def customized_partition(exec_seq): - ''' + """ This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an annotation to note the partition point. - ''' + """ customized_parts = {} start = 0 stop = 0 rank = 0 for element in exec_seq: if isinstance(element, str): - if element == 'SPLIT_NODE': + if element == "SPLIT_NODE": customized_parts[rank] = [(start, stop)] start = stop rank += 1 diff --git a/colossalai/legacy/registry/registry.py b/colossalai/legacy/registry/registry.py index 50d6b74c5617..43644f8a9e73 100644 --- a/colossalai/legacy/registry/registry.py +++ b/colossalai/legacy/registry/registry.py @@ -59,7 +59,7 @@ def get_module(self, module_name: str): for lib in self._third_party_lib: if hasattr(lib, module_name): return getattr(lib, module_name) - raise NameError(f'Module {module_name} not found in the registry {self.name}') + raise NameError(f"Module {module_name} not found in the registry {self.name}") def has(self, module_name: str): """Searches for a module with name `module_name` and returns a boolean value indicating diff --git a/colossalai/legacy/tensor/__init__.py b/colossalai/legacy/tensor/__init__.py index d3278bf1e420..a34870eba068 100644 --- a/colossalai/legacy/tensor/__init__.py +++ b/colossalai/legacy/tensor/__init__.py @@ -6,12 +6,12 @@ from .tensor_spec import ColoTensorSpec __all__ = [ - 'ComputePattern', - 'ComputeSpec', - 'distspec', - 'DistSpecManager', - 'ProcessGroup', - 'ColoTensorSpec', - 'ShardSpec', - 'ReplicaSpec', + "ComputePattern", + "ComputeSpec", + "distspec", + "DistSpecManager", + "ProcessGroup", + "ColoTensorSpec", + "ShardSpec", + "ReplicaSpec", ] diff --git a/colossalai/legacy/tensor/compute_spec.py b/colossalai/legacy/tensor/compute_spec.py index 12f8f36bc613..820aafab687f 100644 --- a/colossalai/legacy/tensor/compute_spec.py +++ b/colossalai/legacy/tensor/compute_spec.py @@ -23,7 +23,7 @@ def __init__(self, compute_pattern: ComputePattern) -> None: self.output_replicate = True def __repr__(self): - return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})' + return f"ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})" def set_output_replicate(self, flag: bool = True): self.output_replicate = flag diff --git a/colossalai/legacy/tensor/const.py b/colossalai/legacy/tensor/const.py index 356e8ecc885a..cbc2b29d66a8 100644 --- a/colossalai/legacy/tensor/const.py +++ b/colossalai/legacy/tensor/const.py @@ -3,4 +3,4 @@ class TensorType(Enum): MODEL = 0 - NONMODEL = 1 # mainly activations + NONMODEL = 1 # mainly activations diff --git a/colossalai/legacy/tensor/dist_spec_mgr.py b/colossalai/legacy/tensor/dist_spec_mgr.py index d97308b04bef..3942b5b7a33c 100644 --- a/colossalai/legacy/tensor/dist_spec_mgr.py +++ b/colossalai/legacy/tensor/dist_spec_mgr.py @@ -20,14 +20,12 @@ def divide(numerator, denominator): Returns: int: the result of exact division. """ - assert denominator != 0, 'denominator can not be zero' - assert numerator % denominator == 0, \ - '{} is not divisible by {}'.format(numerator, denominator) + assert denominator != 0, "denominator can not be zero" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) return numerator // denominator class TransformDistSpec(torch.autograd.Function): - @staticmethod def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func): ctx.old_dist_spec = old_dist_spec @@ -38,12 +36,17 @@ def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backw @staticmethod def backward(ctx, grad_outputs): - return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, - ctx.pg), None, None, None, None, None + return ( + ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, ctx.pg), + None, + None, + None, + None, + None, + ) class DistSpecManager: - _use_autograd_function: bool = True @staticmethod @@ -51,8 +54,9 @@ def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None: pass @staticmethod - def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, - pg: ProcessGroup) -> torch.Tensor: + def _shard_as( + tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup + ) -> torch.Tensor: """_shard_as: shard the tensor w.r.t a distributed specification. Assuming the tensor passed in is a global (replicated) tensor. Args: @@ -62,7 +66,9 @@ def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSp Returns: torch.Tensor: a torch tensor after sharded. """ - assert old_dist_spec.placement.value == 'r', f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!" + assert ( + old_dist_spec.placement.value == "r" + ), f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!" DistSpecManager._sanity_check(old_dist_spec, dist_spec) chunk = tensor @@ -86,9 +92,9 @@ def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> Returns: torch.Tensor: a replicated tensor. """ - assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!" + assert old_dist_spec.placement.value == "s", f"The old_dist_spec of DistSpecManager._gather must be SHARD!" is_cpu_tensor = False - if tensor.device.type == 'cpu': + if tensor.device.type == "cpu": # pytorch lower than 1.11 dose not support gather a cpu tensor. # Therefore, we transfer tensor to GPU before gather. saved_dev = tensor.device @@ -96,14 +102,14 @@ def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> is_cpu_tensor = True buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())] - assert tensor.device.type == 'cuda' + assert tensor.device.type == "cuda" dist.all_gather(buffer, tensor, group=pg.tp_process_group()) for i in range(len(old_dist_spec.dims) - 1, -1, -1): new_buffer = [] dim = old_dist_spec.dims[i] num_parts = old_dist_spec.num_partitions[i] for start in range(0, len(buffer), num_parts): - new_buffer.append(torch.cat(buffer[start:start + num_parts], dim)) + new_buffer.append(torch.cat(buffer[start : start + num_parts], dim)) buffer = new_buffer assert len(buffer) == 1 @@ -112,15 +118,17 @@ def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> return buffer[0] @staticmethod - def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, - pg: ProcessGroup) -> torch.Tensor: + def _all_to_all( + tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup + ) -> torch.Tensor: world_size = pg.tp_world_size() if world_size == 1: return tensor - assert tensor.device.type == "cuda", \ - "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \ + assert tensor.device.type == "cuda", ( + "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " f"collective function, however, we got {tensor.device.type} device" + ) gather_dim = old_dist_spec.dims[0] scatter_dim = dist_spec.dims[0] @@ -164,8 +172,9 @@ def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, p return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg) @staticmethod - def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, - pg: ProcessGroup) -> torch.Tensor: + def handle_trans_spec( + tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup + ) -> torch.Tensor: assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec" assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec" @@ -174,7 +183,7 @@ def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: (DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r, (DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s, (DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r, - (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s + (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s, } forward_trans_handle = trans_funcs[trans_func_key] @@ -183,8 +192,9 @@ def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)] - return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, - backward_trans_handle) + return TransformDistSpec.apply( + tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, backward_trans_handle + ) @staticmethod @contextmanager diff --git a/colossalai/legacy/tensor/distspec.py b/colossalai/legacy/tensor/distspec.py index 3a09f1426e31..efef9904ec10 100644 --- a/colossalai/legacy/tensor/distspec.py +++ b/colossalai/legacy/tensor/distspec.py @@ -1,12 +1,12 @@ from enum import Enum from typing import List -__all__ = ['ReplicaSpec', 'ShardSpec'] +__all__ = ["ReplicaSpec", "ShardSpec"] class DistPlacementPattern(Enum): - REPLICATE = 'r' - SHARD = 's' + REPLICATE = "r" + SHARD = "s" class _DistSpec: @@ -25,7 +25,6 @@ class _DistSpec: """ def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info): - self.placement = dist_placement_pattern for k, v in meta_info.items(): setattr(self, k, v) @@ -34,15 +33,15 @@ def __eq__(self, other: "_DistSpec") -> bool: if dir(self) != dir(other): return False for attr in dir(self): - if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr): + if not attr.startswith("__") and getattr(self, attr) != getattr(other, attr): return False return True def __repr__(self) -> str: attr_list = [] for attr in dir(self): - if not attr.startswith('__'): - attr_list.append(f'{attr}={str(getattr(self, attr))}') + if not attr.startswith("__"): + attr_list.append(f"{attr}={str(getattr(self, attr))}") attr_str = ", ".join(attr_list) return "DistSpec(" + attr_str + ")" diff --git a/colossalai/legacy/tensor/process_group.py b/colossalai/legacy/tensor/process_group.py index 8d2e9a616d76..ec6043163336 100644 --- a/colossalai/legacy/tensor/process_group.py +++ b/colossalai/legacy/tensor/process_group.py @@ -7,13 +7,12 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta): - def __init__(self): # distributed settings # use this dict to record all Pytorch ProcessGroups self.dict = {} # set a distributed logger - self.logger = get_dist_logger('ProcessGroup') + self.logger = get_dist_logger("ProcessGroup") def log_pg_init(self, rank_list: List[int], backend: str): str_list = ["Pytorch ProcessGroup Init:"] @@ -21,9 +20,8 @@ def log_pg_init(self, rank_list: List[int], backend: str): str_list.append(f"ranks: {rank_list}") self.logger.info("\n\t".join(str_list), ranks=[0]) - def get(self, rank_list: List[int], backend: str = 'nccl'): - """Reuse Pytorch ProcessGroup when such a group is initialized - """ + def get(self, rank_list: List[int], backend: str = "nccl"): + """Reuse Pytorch ProcessGroup when such a group is initialized""" # we need to convert the passed list to a tuple # since List is unhashable processgroup_key = (backend, tuple(rank_list)) @@ -51,11 +49,13 @@ class ProcessGroup: dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks). """ - def __init__(self, - rank: Optional[int] = None, - ranks: Optional[List[int]] = None, - tp_degree: Optional[int] = None, - dp_degree: Optional[int] = None) -> None: + def __init__( + self, + rank: Optional[int] = None, + ranks: Optional[List[int]] = None, + tp_degree: Optional[int] = None, + dp_degree: Optional[int] = None, + ) -> None: if not torch.distributed.is_initialized(): self.is_init = False return @@ -64,13 +64,13 @@ def __init__(self, self._rank = torch.distributed.get_rank() if rank is not None: - assert self._rank == rank # make sure that the global rank is correct + assert self._rank == rank # make sure that the global rank is correct if ranks is None: self._rank_list = list(range(torch.distributed.get_world_size())) else: self._rank_list = ranks - self._rank_list.sort() # ensure that the list is in order + self._rank_list.sort() # ensure that the list is in order self._world_size = len(self._rank_list) @@ -79,31 +79,36 @@ def __init__(self, self._tp_degree = 1 elif dp_degree and not tp_degree: self._dp_degree = dp_degree - assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None" + assert ( + self._world_size % self._dp_degree == 0 + ), f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None" self._tp_degree = self._world_size // dp_degree elif not dp_degree and tp_degree: self._tp_degree = tp_degree - assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None" + assert ( + self._world_size % self._tp_degree == 0 + ), f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None" self._dp_degree = self._world_size // tp_degree else: self._dp_degree = dp_degree self._tp_degree = tp_degree - assert self._dp_degree * self._tp_degree == self._world_size, \ - f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \ + assert self._dp_degree * self._tp_degree == self._world_size, ( + f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" f"and TP degree {self._tp_degree}" + ) self._tp_rank_list = None self._dp_rank_list = None for i in range(self._dp_degree): i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] - PYTORCHPGDICT_.get(i_tp_list, 'nccl') + PYTORCHPGDICT_.get(i_tp_list, "nccl") if self._rank in i_tp_list: self._tp_rank_list = i_tp_list for j in range(self._tp_degree): j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] - PYTORCHPGDICT_.get(j_dp_list, 'nccl') + PYTORCHPGDICT_.get(j_dp_list, "nccl") if self._rank in j_dp_list: self._dp_rank_list = j_dp_list @@ -119,11 +124,11 @@ def set_cpu_groups(self): for i in range(self._dp_degree): i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)] - PYTORCHPGDICT_.get(i_tp_list, 'gloo') + PYTORCHPGDICT_.get(i_tp_list, "gloo") for j in range(self._tp_degree): j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)] - PYTORCHPGDICT_.get(j_dp_list, 'gloo') + PYTORCHPGDICT_.get(j_dp_list, "gloo") self._has_cpu_groups = True @@ -145,7 +150,7 @@ def __repr__(self): else: return "ProcessGroup not initialized" - def __eq__(self, obj: 'ProcessGroup') -> bool: + def __eq__(self, obj: "ProcessGroup") -> bool: if not isinstance(obj, ProcessGroup): return False if self._rank != obj._rank: @@ -260,7 +265,7 @@ def dp_process_group(self): Returns: `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group. """ - return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl') + return PYTORCHPGDICT_.get(self._dp_rank_list, "nccl") def tp_process_group(self): """tp_process_group @@ -270,7 +275,7 @@ def tp_process_group(self): Returns: `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group. """ - return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl') + return PYTORCHPGDICT_.get(self._tp_rank_list, "nccl") def cpu_dp_process_group(self): """cpu_dp_process_group @@ -283,7 +288,7 @@ def cpu_dp_process_group(self): `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group. """ assert self._has_cpu_groups - return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo') + return PYTORCHPGDICT_.get(self._dp_rank_list, "gloo") def cpu_tp_process_group(self): """cpu_tp_process_group @@ -296,7 +301,7 @@ def cpu_tp_process_group(self): `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group. """ assert self._has_cpu_groups - return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo') + return PYTORCHPGDICT_.get(self._tp_rank_list, "gloo") def get_ranks_in_dp(self) -> List[int]: """get_ranks_in_dp diff --git a/colossalai/legacy/tensor/tensor_spec.py b/colossalai/legacy/tensor/tensor_spec.py index aa792e507639..5bdd384e5e15 100644 --- a/colossalai/legacy/tensor/tensor_spec.py +++ b/colossalai/legacy/tensor/tensor_spec.py @@ -9,12 +9,13 @@ @dataclass class ColoTensorSpec: - """ ColoTensorSpec + """ColoTensorSpec A data class for specifications of the `ColoTensor`. It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`. The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`. """ + pg: ProcessGroup dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE) compute_attr: Optional[ComputeSpec] = None diff --git a/colossalai/legacy/trainer/__init__.py b/colossalai/legacy/trainer/__init__.py index 84e53dc4e87a..e4fddc7c1c9f 100644 --- a/colossalai/legacy/trainer/__init__.py +++ b/colossalai/legacy/trainer/__init__.py @@ -1,3 +1,3 @@ from ._trainer import Trainer -__all__ = ['Trainer'] +__all__ = ["Trainer"] diff --git a/colossalai/legacy/trainer/_trainer.py b/colossalai/legacy/trainer/_trainer.py index 1cb99fcc90ed..46e708622237 100644 --- a/colossalai/legacy/trainer/_trainer.py +++ b/colossalai/legacy/trainer/_trainer.py @@ -151,7 +151,7 @@ def _call_hooks(self, func, output=None): @staticmethod def _should_display_progress(display_progress: bool): """Only display progress on DP rank 0, TP rank 0 and PP last rank""" - return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()) + return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() def _train_epoch( self, @@ -293,8 +293,7 @@ def fit( assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}" for hook in hooks: - assert isinstance(hook, BaseHook), \ - f'expected the hook to be of type BaseHook, but got {type(hook)}' + assert isinstance(hook, BaseHook), f"expected the hook to be of type BaseHook, but got {type(hook)}" else: hooks = [] self.hooks = hooks diff --git a/colossalai/legacy/trainer/hooks/__init__.py b/colossalai/legacy/trainer/hooks/__init__.py index bf9cc6421b67..290aeb64a04d 100644 --- a/colossalai/legacy/trainer/hooks/__init__.py +++ b/colossalai/legacy/trainer/hooks/__init__.py @@ -11,7 +11,16 @@ from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook __all__ = [ - 'BaseHook', 'MetricHook', 'LossHook', 'AccuracyHook', 'LogMetricByEpochHook', 'TensorboardHook', - 'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook', 'ThroughputHook', 'LogMetricByStepHook', - 'SaveCheckpointHook' + "BaseHook", + "MetricHook", + "LossHook", + "AccuracyHook", + "LogMetricByEpochHook", + "TensorboardHook", + "LogTimingByEpochHook", + "LogMemoryByEpochHook", + "LRSchedulerHook", + "ThroughputHook", + "LogMetricByStepHook", + "SaveCheckpointHook", ] diff --git a/colossalai/legacy/trainer/hooks/_base_hook.py b/colossalai/legacy/trainer/hooks/_base_hook.py index cca8e081ec88..fc883134203f 100644 --- a/colossalai/legacy/trainer/hooks/_base_hook.py +++ b/colossalai/legacy/trainer/hooks/_base_hook.py @@ -18,24 +18,16 @@ def __init__(self, priority: int) -> None: self.priority = priority def after_hook_is_attached(self, trainer): - """Actions after hooks are attached to trainer. - """ - pass + """Actions after hooks are attached to trainer.""" def before_train(self, trainer): - """Actions before training. - """ - pass + """Actions before training.""" def after_train(self, trainer): - """Actions after training. - """ - pass + """Actions after training.""" def before_train_iter(self, trainer): - """Actions before running a training iteration. - """ - pass + """Actions before running a training iteration.""" def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): """Actions after running a training iteration. @@ -46,42 +38,27 @@ def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor) label (:class:`torch.Tensor`): Labels of the input data. loss (:class:`torch.Tensor`): Loss between the output and input data. """ - pass def before_train_epoch(self, trainer): - """Actions before starting a training epoch. - """ - pass + """Actions before starting a training epoch.""" def after_train_epoch(self, trainer): - """Actions after finishing a training epoch. - """ - pass + """Actions after finishing a training epoch.""" def before_test(self, trainer): - """Actions before evaluation. - """ - pass + """Actions before evaluation.""" def after_test(self, trainer): - """Actions after evaluation. - """ - pass + """Actions after evaluation.""" def before_test_epoch(self, trainer): - """Actions before starting a testing epoch. - """ - pass + """Actions before starting a testing epoch.""" def after_test_epoch(self, trainer): - """Actions after finishing a testing epoch. - """ - pass + """Actions after finishing a testing epoch.""" def before_test_iter(self, trainer): - """Actions before running a testing iteration. - """ - pass + """Actions before running a testing iteration.""" def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): """Actions after running a testing iteration. @@ -92,7 +69,6 @@ def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): label (:class:`torch.Tensor`): Labels of the input data loss (:class:`torch.Tensor`): Loss between the output and input data """ - pass def init_runner_states(self, trainer, key, val): """Initializes trainer's state. diff --git a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py index cda10030bf65..50c80759867e 100644 --- a/colossalai/legacy/trainer/hooks/_checkpoint_hook.py +++ b/colossalai/legacy/trainer/hooks/_checkpoint_hook.py @@ -27,12 +27,14 @@ class SaveCheckpointHook(BaseHook): depend on the hooks order in the hook list. """ - def __init__(self, - interval: int = 1, - checkpoint_dir: str = None, - model: torch.nn.Module = None, - save_by_iter: bool = False, - priority: int = 10): + def __init__( + self, + interval: int = 1, + checkpoint_dir: str = None, + model: torch.nn.Module = None, + save_by_iter: bool = False, + priority: int = 10, + ): super().__init__(priority=priority) self.interval = interval self.checkpoint_dir = checkpoint_dir @@ -52,22 +54,23 @@ def after_hook_is_attached(self, trainer): self.model = self.model if self.model is not None else trainer.engine.model def after_train_iter(self, trainer, output, label, loss): - """Saves the model after a training iter. - """ + """Saves the model after a training iter.""" # save by interval if self.save_by_iter and trainer.cur_step % self.interval == 0: - save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, - self._lr_scheduler) - self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}', - ranks=[0]) + save_checkpoint( + self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, self._lr_scheduler + ) + self.logger.info( + f"checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}", ranks=[0] + ) else: pass def after_train_epoch(self, trainer): - """Saves the model after a training epoch. - """ + """Saves the model after a training epoch.""" # save by interval if trainer.cur_epoch % self.interval == 0: - save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, - self._lr_scheduler) - self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0]) + save_checkpoint( + self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, self._lr_scheduler + ) + self.logger.info(f"checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}", ranks=[0]) diff --git a/colossalai/legacy/trainer/hooks/_commons_.py b/colossalai/legacy/trainer/hooks/_commons_.py index 4923b8cba6c0..18da38298704 100644 --- a/colossalai/legacy/trainer/hooks/_commons_.py +++ b/colossalai/legacy/trainer/hooks/_commons_.py @@ -3,7 +3,7 @@ def _format_number(val, prec=5): if isinstance(val, float): - return f'{val:.{prec}g}' + return f"{val:.{prec}g}" elif torch.is_tensor(val) and torch.is_floating_point(val): - return f'{val.item():.{prec}g}' + return f"{val.item():.{prec}g}" return val diff --git a/colossalai/legacy/trainer/hooks/_log_hook.py b/colossalai/legacy/trainer/hooks/_log_hook.py index b1a398ce7f71..c1cf0ca5228b 100644 --- a/colossalai/legacy/trainer/hooks/_log_hook.py +++ b/colossalai/legacy/trainer/hooks/_log_hook.py @@ -51,20 +51,20 @@ def __init__(self, priority: int = 10): super().__init__(priority) def after_train_iter(self, trainer, *args): - trainer.states['step_metrics'] = dict() - for metric_name, metric_calculator in trainer.states['metrics']['train'].items(): + trainer.states["step_metrics"] = dict() + for metric_name, metric_calculator in trainer.states["metrics"]["train"].items(): if isinstance(metric_calculator, ThroughputMetric): - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info() + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_info() else: - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_value() def after_test_iter(self, trainer, *args): - trainer.states['step_metrics'] = dict() - for metric_name, metric_calculator in trainer.states['metrics']['test'].items(): + trainer.states["step_metrics"] = dict() + for metric_name, metric_calculator in trainer.states["metrics"]["test"].items(): if isinstance(metric_calculator, ThroughputMetric): - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info() + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_info() else: - trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() + trainer.states["step_metrics"][metric_name.lower()] = metric_calculator.get_last_step_value() @HOOKS.register_module @@ -85,24 +85,24 @@ def __init__(self, logger, interval: int = 1, priority: int = 10) -> None: def _get_str(self, trainer, mode): msg = [] - for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): - msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}') - msg = ' | '.join(msg) + for metric_name, metric_calculator in trainer.states["metrics"][mode].items(): + msg.append(f"{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}") + msg = " | ".join(msg) return msg def after_train_epoch(self, trainer): if self._is_epoch_to_log(trainer): - msg = self._get_str(trainer=trainer, mode='train') + msg = self._get_str(trainer=trainer, mode="train") if self._is_rank_to_log: - self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}') + self.logger.info(f"[Epoch {trainer.cur_epoch} / Train]: {msg}") # f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') def after_test_epoch(self, trainer): if self._is_epoch_to_log(trainer): - msg = self._get_str(trainer=trainer, mode='test') + msg = self._get_str(trainer=trainer, mode="test") if self._is_rank_to_log: - self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}') + self.logger.info(f"[Epoch {trainer.cur_epoch} / Test]: {msg}") # f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}') @@ -145,8 +145,11 @@ def __init__( self._is_valid_rank_to_log = True # check for - if gpc.is_initialized(ParallelMode.PIPELINE) and \ - not gpc.is_last_rank(ParallelMode.PIPELINE) and self._is_valid_rank_to_log: + if ( + gpc.is_initialized(ParallelMode.PIPELINE) + and not gpc.is_last_rank(ParallelMode.PIPELINE) + and self._is_valid_rank_to_log + ): raise ValueError("Tensorboard hook can only log on the last rank of pipeline process group") if self._is_valid_rank_to_log: @@ -157,38 +160,38 @@ def __init__( rank = 0 # create workspace - log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}') + log_dir = osp.join(log_dir, f"{parallel_mode}_rank_{rank}") os.makedirs(log_dir, exist_ok=True) - self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_rank_{rank}') + self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f"_rank_{rank}") def _log_by_iter(self, trainer, mode: str): - for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): + for metric_name, metric_calculator in trainer.states["metrics"][mode].items(): if metric_calculator.epoch_only: continue val = metric_calculator.get_last_step_value() if self._is_valid_rank_to_log: - self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step) + self.writer.add_scalar(f"{metric_name}/{mode}", val, trainer.cur_step) def _log_by_epoch(self, trainer, mode: str): - for metric_name, metric_calculator in trainer.states['metrics'][mode].items(): + for metric_name, metric_calculator in trainer.states["metrics"][mode].items(): if metric_calculator.epoch_only: val = metric_calculator.get_accumulated_value() if self._is_valid_rank_to_log: - self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step) + self.writer.add_scalar(f"{metric_name}/{mode}", val, trainer.cur_step) def after_test_iter(self, trainer, *args): - self._log_by_iter(trainer, mode='test') + self._log_by_iter(trainer, mode="test") def after_test_epoch(self, trainer): - self._log_by_epoch(trainer, mode='test') + self._log_by_epoch(trainer, mode="test") def after_train_iter(self, trainer, *args): - self._log_by_iter(trainer, mode='train') + self._log_by_iter(trainer, mode="train") def after_train_epoch(self, trainer): - self._log_by_epoch(trainer, mode='train') + self._log_by_epoch(trainer, mode="train") @HOOKS.register_module @@ -206,13 +209,15 @@ class LogTimingByEpochHook(LogByEpochHook): ignore_num_train_steps (int, optional): Number of training steps to ignore, defaults to 0. """ - def __init__(self, - timer: MultiTimer, - logger: DistributedLogger, - interval: int = 1, - priority: int = 10, - log_eval: bool = True, - ignore_num_train_steps: int = 0) -> None: + def __init__( + self, + timer: MultiTimer, + logger: DistributedLogger, + interval: int = 1, + priority: int = 10, + log_eval: bool = True, + ignore_num_train_steps: int = 0, + ) -> None: super().__init__(logger=logger, interval=interval, priority=priority) self._timer = timer self._log_eval = log_eval @@ -229,33 +234,31 @@ def _get_message(self, mode): if timer_name.startswith(mode): last_elapsed_time = timer.get_elapsed_time() if timer.has_history: - if timer_name == 'Train-step' and not self._is_train_step_history_trimmed: - timer._history = timer._history[self._ignore_num_train_steps:] + if timer_name == "Train-step" and not self._is_train_step_history_trimmed: + timer._history = timer._history[self._ignore_num_train_steps :] self._is_train_step_history_trimmed = True history_mean = timer.get_history_mean() - history_sum = timer.get_history_sum() + timer.get_history_sum() msg.append( - f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s' + f"{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s" ) else: - msg.append(f'{timer_name}: last = {_format_number(last_elapsed_time)} s') + msg.append(f"{timer_name}: last = {_format_number(last_elapsed_time)} s") - msg = ' | '.join(msg) + msg = " | ".join(msg) return msg def after_train_epoch(self, trainer): - """Writes log after finishing a training epoch. - """ + """Writes log after finishing a training epoch.""" if self._is_epoch_to_log(trainer) and self._is_rank_to_log: - msg = self._get_message('Train') - self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}') + msg = self._get_message("Train") + self.logger.info(f"[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}") def after_test_epoch(self, trainer): - """Writes log after finishing a testing epoch. - """ + """Writes log after finishing a testing epoch.""" if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: - msg = self._get_message('Test') - self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}') + msg = self._get_message("Test") + self.logger.info(f"[Epoch {trainer.cur_epoch} / Test]: {msg}") @HOOKS.register_module @@ -272,31 +275,28 @@ class LogMemoryByEpochHook(LogByEpochHook): """ def __init__( - self, - logger: DistributedLogger, - interval: int = 1, - priority: int = 10, - log_eval: bool = True, - report_cpu: bool = False, # no reference + self, + logger: DistributedLogger, + interval: int = 1, + priority: int = 10, + log_eval: bool = True, + report_cpu: bool = False, # no reference ) -> None: super().__init__(logger=logger, interval=interval, priority=priority) self._log_eval = log_eval self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() def before_train(self, trainer): - """Resets before training. - """ + """Resets before training.""" if self._is_epoch_to_log(trainer) and self._is_rank_to_log: - report_memory_usage('Before-train', self.logger) + report_memory_usage("Before-train", self.logger) def after_train_epoch(self, trainer): - """Writes log after finishing a training epoch. - """ + """Writes log after finishing a training epoch.""" if self._is_epoch_to_log(trainer) and self._is_rank_to_log: - report_memory_usage(f'[Epoch {trainer.cur_epoch} / Train]', self.logger) + report_memory_usage(f"[Epoch {trainer.cur_epoch} / Train]", self.logger) def after_test(self, trainer): - """Reports after testing. - """ + """Reports after testing.""" if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval: - report_memory_usage(f'[Epoch {trainer.cur_epoch} / Test]', self.logger) + report_memory_usage(f"[Epoch {trainer.cur_epoch} / Test]", self.logger) diff --git a/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py index 6d60966da12a..d14db563473c 100644 --- a/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py +++ b/colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py @@ -34,15 +34,16 @@ def __init__( def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) - trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch, - initial_lr=self.lr_scheduler.get_last_lr()[0]) + trainer.states["metrics"]["train"]["LR"] = LearningRateMetric( + epoch_only=self.by_epoch, initial_lr=self.lr_scheduler.get_last_lr()[0] + ) def after_train_epoch(self, trainer): if self.by_epoch: self.lr_scheduler.step() - trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0]) + trainer.states["metrics"]["train"]["LR"].update(self.lr_scheduler.get_last_lr()[0]) def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): if not self.by_epoch: self.lr_scheduler.step() - trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0]) + trainer.states["metrics"]["train"]["LR"].update(self.lr_scheduler.get_last_lr()[0]) diff --git a/colossalai/legacy/trainer/hooks/_metric_hook.py b/colossalai/legacy/trainer/hooks/_metric_hook.py index 899e4d08a5c9..35a7f0a156ab 100644 --- a/colossalai/legacy/trainer/hooks/_metric_hook.py +++ b/colossalai/legacy/trainer/hooks/_metric_hook.py @@ -35,8 +35,7 @@ def __init__(self, epoch_only: bool): @property def epoch_only(self): - """Returns :attr:`epoch_only`. - """ + """Returns :attr:`epoch_only`.""" return self._epoch_only @abstractmethod @@ -44,20 +43,16 @@ def reset(self) -> None: """Resets the metric to it's initial state. By default, this is called at the start of each epoch. """ - pass @abstractmethod def update(self, *args, **kwargs) -> None: """Updates the metric's state using the passed batch output. By default, this is called once for each batch. """ - pass @abstractmethod def get_last_step_value(self) -> float: - """Returns the metric value in the last iteration. - """ - pass + """Returns the metric value in the last iteration.""" @abstractmethod def get_accumulated_value(self): @@ -67,7 +62,6 @@ def get_accumulated_value(self): :return: the actual quantity of interest :rtype: Any """ - pass @staticmethod @abstractmethod @@ -77,7 +71,6 @@ def is_better(a, b) -> bool: :return: The result of comparison :rtype: bool """ - pass class LossMetric(Metric): @@ -94,8 +87,7 @@ def __init__(self, epoch_only): self.count = 0 def reset(self) -> None: - """Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero. - """ + """Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.""" self.last_step_loss.zero_() self.accum_loss.zero_() self.count = 0 @@ -114,8 +106,7 @@ def update(self, loss) -> None: self.count += 1 def get_accumulated_value(self): - """Returns accumulated loss. - """ + """Returns accumulated loss.""" if gpc.is_initialized(ParallelMode.DATA): dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA)) self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA)) @@ -124,8 +115,7 @@ def get_accumulated_value(self): return self.accum_loss.item() def get_last_step_value(self) -> float: - """Returns :attr:`last_step_loss`. - """ + """Returns :attr:`last_step_loss`.""" return self.last_step_loss.cpu().item() @staticmethod @@ -141,7 +131,7 @@ class LearningRateMetric(Metric): initial_lr (float, optional): Initial learning rate, defaults to 0.0. """ - def __init__(self, epoch_only: bool, initial_lr: float = 0.): + def __init__(self, epoch_only: bool, initial_lr: float = 0.0): super().__init__(epoch_only=epoch_only) self.lr = initial_lr @@ -241,8 +231,8 @@ def __init__( self._is_stage_to_compute = is_no_pp_or_last_stage() def _check_metric_states_initialization(self, trainer): - if 'metrics' not in trainer.states: - self.init_runner_states(trainer, 'metrics', dict(train={}, test={})) + if "metrics" not in trainer.states: + self.init_runner_states(trainer, "metrics", dict(train={}, test={})) @HOOKS.register_module @@ -266,8 +256,8 @@ def after_hook_is_attached(self, trainer): self.test_loss = LossMetric(epoch_only=True) # register the metric calculator - trainer.states['metrics']['train']['Loss'] = self.train_loss - trainer.states['metrics']['test']['Loss'] = self.test_loss + trainer.states["metrics"]["train"]["Loss"] = self.train_loss + trainer.states["metrics"]["test"]["Loss"] = self.test_loss def before_train_epoch(self, trainer): if self._is_stage_to_compute: @@ -307,7 +297,7 @@ def after_hook_is_attached(self, trainer): self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func) # register the metric - trainer.states['metrics']['test']['Accuracy'] = self.metric + trainer.states["metrics"]["test"]["Accuracy"] = self.metric def before_test(self, trainer): if self._is_stage_to_compute: @@ -356,8 +346,9 @@ def get_last_step_value(self) -> float: if self._use_local: self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) else: - self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ - gpc.get_world_size(ParallelMode.DATA) + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / gpc.get_world_size( + ParallelMode.DATA + ) self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) @@ -367,8 +358,9 @@ def get_last_step_info(self) -> str: if self._use_local: self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) else: - self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \ - gpc.get_world_size(ParallelMode.DATA) + self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / gpc.get_world_size( + ParallelMode.DATA + ) self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA) sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item()) @@ -379,8 +371,9 @@ def get_last_step_info(self) -> str: return f"{sample_per_sec} sample_per_sec" def get_accumulated_value(self) -> float: - self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \ - gpc.get_world_size(ParallelMode.DATA) + self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / gpc.get_world_size( + ParallelMode.DATA + ) self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA) return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item() @@ -411,14 +404,16 @@ def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: i def after_hook_is_attached(self, trainer): self._check_metric_states_initialization(trainer) if self._is_stage_to_compute: - self.metric = ThroughputMetric(epoch_only=True, - ignored_steps=self.ignored_steps, - tflop_per_step=self._tflop_per_step, - use_local=self._use_local) + self.metric = ThroughputMetric( + epoch_only=True, + ignored_steps=self.ignored_steps, + tflop_per_step=self._tflop_per_step, + use_local=self._use_local, + ) # register the metric - trainer.states['metrics']['train']['Throughput'] = self.metric - trainer.states['metrics']['test']['Throughput'] = self.metric + trainer.states["metrics"]["train"]["Throughput"] = self.metric + trainer.states["metrics"]["test"]["Throughput"] = self.metric def before_train_epoch(self, trainer): if self._is_stage_to_compute: @@ -426,8 +421,9 @@ def before_train_epoch(self, trainer): def after_train_iter(self, trainer, *args): if self._is_stage_to_compute: - self.metric.update(trainer.engine.schedule.batch_size, - trainer._timer.get_timer('Train-step').get_elapsed_time()) + self.metric.update( + trainer.engine.schedule.batch_size, trainer._timer.get_timer("Train-step").get_elapsed_time() + ) def before_test(self, trainer): if self._is_stage_to_compute: @@ -435,5 +431,6 @@ def before_test(self, trainer): def after_test_iter(self, trainer, *args): if self._is_stage_to_compute: - self.metric.update(trainer.engine.schedule.batch_size, - trainer._timer.get_timer('Test-step').get_elapsed_time()) + self.metric.update( + trainer.engine.schedule.batch_size, trainer._timer.get_timer("Test-step").get_elapsed_time() + ) diff --git a/colossalai/legacy/utils/__init__.py b/colossalai/legacy/utils/__init__.py index ae358f8bebcb..86984edeec65 100644 --- a/colossalai/legacy/utils/__init__.py +++ b/colossalai/legacy/utils/__init__.py @@ -26,28 +26,28 @@ ) __all__ = [ - 'DataParallelSampler', - 'get_dataloader', - 'save_checkpoint', - 'load_checkpoint', - 'colo_device_memory_capacity', - 'colo_device_memory_used', - 'colo_get_cpu_memory_capacity', - 'colo_set_cpu_memory_capacity', - 'colo_set_process_memory_fraction', - 'report_memory_usage', - 'clip_grad_norm_fp32', - 'copy_tensor_parallel_attributes', - 'count_zeros_fp32', - 'is_dp_rank_0', - 'is_model_parallel_parameter', - 'is_no_pp_or_last_stage', - 'is_tp_rank_0', - 'is_using_ddp', - 'is_using_pp', - 'is_using_sequence', - 'param_is_not_tensor_parallel_duplicate', - 'print_rank_0', - 'switch_virtual_pipeline_parallel_rank', - 'sync_model_param', + "DataParallelSampler", + "get_dataloader", + "save_checkpoint", + "load_checkpoint", + "colo_device_memory_capacity", + "colo_device_memory_used", + "colo_get_cpu_memory_capacity", + "colo_set_cpu_memory_capacity", + "colo_set_process_memory_fraction", + "report_memory_usage", + "clip_grad_norm_fp32", + "copy_tensor_parallel_attributes", + "count_zeros_fp32", + "is_dp_rank_0", + "is_model_parallel_parameter", + "is_no_pp_or_last_stage", + "is_tp_rank_0", + "is_using_ddp", + "is_using_pp", + "is_using_sequence", + "param_is_not_tensor_parallel_duplicate", + "print_rank_0", + "switch_virtual_pipeline_parallel_rank", + "sync_model_param", ] diff --git a/colossalai/legacy/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py index add690f28cc0..387e1c54ec87 100644 --- a/colossalai/legacy/utils/activation_checkpoint.py +++ b/colossalai/legacy/utils/activation_checkpoint.py @@ -28,7 +28,6 @@ def copy_to_device(obj, device): class CheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, activation_offload=False, *args): check_backward_validity(args) @@ -42,7 +41,7 @@ def forward(ctx, run_function, activation_offload=False, *args): ctx.fwd_seed_states = get_states(copy=True) ctx.fwd_current_mode = get_current_mode() - if hasattr(torch, 'is_autocast_enabled'): + if hasattr(torch, "is_autocast_enabled"): ctx.had_autocast_in_fwd = torch.is_autocast_enabled() else: ctx.had_autocast_in_fwd = False @@ -62,7 +61,7 @@ def forward(ctx, run_function, activation_offload=False, *args): for i, arg in enumerate(args): if torch.is_tensor(arg): if activation_offload: - tensor_inputs.append(copy_to_device(arg, 'cpu')) + tensor_inputs.append(copy_to_device(arg, "cpu")) else: tensor_inputs.append(arg) ctx.tensor_indices.append(i) @@ -79,8 +78,10 @@ def forward(ctx, run_function, activation_offload=False, *args): @staticmethod def backward(ctx, *args): if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is " - "passed to .backward(). Please use .backward() and do not pass its `inputs` argument.") + raise RuntimeError( + "Checkpointing is not compatible with .grad() or when an `inputs` parameter is " + "passed to .backward(). Please use .backward() and do not pass its `inputs` argument." + ) # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices @@ -131,8 +132,7 @@ def backward(ctx, *args): outputs_with_grad.append(outputs[i]) args_with_grad.append(args[i]) if len(outputs_with_grad) == 0: - raise RuntimeError("none of output has requires_grad=True," - " this checkpoint() is not necessary") + raise RuntimeError("none of output has requires_grad=True," " this checkpoint() is not necessary") torch.autograd.backward(outputs_with_grad, args_with_grad) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) return (None, None) + grads @@ -169,7 +169,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): fwd_current_mode = get_current_mode() # check if use autocast - if hasattr(torch, 'is_autocast_enabled'): + if hasattr(torch, "is_autocast_enabled"): has_autocast_in_fwd = torch.is_autocast_enabled() else: has_autocast_in_fwd = False @@ -179,7 +179,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args): weak_holder_list = [] # class for weakref.ref - class Holder(): + class Holder: pass # return a Holder object for later unpack process @@ -226,19 +226,20 @@ def inner_unpack(packed): # rerun forward, the inner_pack will store all the activations in storage if has_autocast_in_fwd: - with torch.enable_grad(), \ - torch.cuda.amp.autocast(), \ - torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks( + inner_pack, inner_unpack + ): _unused = function(*args) else: - with torch.enable_grad(), \ - torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): + with torch.enable_grad(), torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack): _unused = function(*args) if x not in storage: - raise RuntimeError("Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" - " recomputation being triggered in between, this is not currently supported. Please" - " open an issue with details on your use case so that we can prioritize adding this.") + raise RuntimeError( + "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" + " recomputation being triggered in between, this is not currently supported. Please" + " open an issue with details on your use case so that we can prioritize adding this." + ) return storage[x] diff --git a/colossalai/legacy/utils/checkpoint/__init__.py b/colossalai/legacy/utils/checkpoint/__init__.py index 558a956b31ac..35ce19ea1c69 100644 --- a/colossalai/legacy/utils/checkpoint/__init__.py +++ b/colossalai/legacy/utils/checkpoint/__init__.py @@ -1,3 +1,3 @@ from .module_checkpoint import load_checkpoint, save_checkpoint -__all__ = ['save_checkpoint', 'load_checkpoint'] +__all__ = ["save_checkpoint", "load_checkpoint"] diff --git a/colossalai/legacy/utils/checkpoint/module_checkpoint.py b/colossalai/legacy/utils/checkpoint/module_checkpoint.py index 9bd2907abf9d..1d691e5c8f97 100644 --- a/colossalai/legacy/utils/checkpoint/module_checkpoint.py +++ b/colossalai/legacy/utils/checkpoint/module_checkpoint.py @@ -9,13 +9,15 @@ from .utils import gather_tensor, scatter_tensor -def save_checkpoint(path: str, - epoch: int, - model: torch.nn.Module, - optimizer: Optional[OptimizerWrapper] = None, - lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, - *args, - **kwargs): +def save_checkpoint( + path: str, + epoch: int, + model: torch.nn.Module, + optimizer: Optional[OptimizerWrapper] = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + *args, + **kwargs, +): """save_checkpoint save a model, whose parameters are `ColoTensor`s. Args: @@ -30,7 +32,7 @@ def save_checkpoint(path: str, # save the dist context about the tensors in a new dict, while still maintain the original dict. for k, v in model_state.items(): if isinstance(v, ColoTensor): - gather_tensor(v) # gather shared tensors to rank0 + gather_tensor(v) # gather shared tensors to rank0 # don't recover tensors in rank0, since the dict is only a copy of model if rank == 0: @@ -39,10 +41,10 @@ def save_checkpoint(path: str, if isinstance(v, ColoTensor): assert v.save_ready assert v.is_replicate() - delattr(v, 'save_ready') + delattr(v, "save_ready") # model saving - save_state = {'epoch': epoch, 'model': model_state} - torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs) + save_state = {"epoch": epoch, "model": model_state} + torch.save(save_state, path + "/epoch_{}_model.pth".format(epoch), *args, **kwargs) # delete old dicts del model_state @@ -52,35 +54,37 @@ def save_checkpoint(path: str, if optimizer is not None: mapping = dict() optim_state = optimizer.state_dict() - for k, v in optim_state['state'].items(): + for k, v in optim_state["state"].items(): for n, t in v.items(): if isinstance(t, ColoTensor): mapping[(k, n)] = t.dist_spec gather_tensor(t) if rank == 0: - save_state = {'epoch': epoch, 'optim': optim_state} - torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs) + save_state = {"epoch": epoch, "optim": optim_state} + torch.save(save_state, path + "/epoch_{}_optim.pth".format(epoch), *args, **kwargs) # recover colo tensors in rank0 - for k, v in optimizer.state_dict()['state'].items(): + for k, v in optimizer.state_dict()["state"].items(): for n, t in v.items(): if isinstance(t, ColoTensor): - assert hasattr(t, 'save_ready') + assert hasattr(t, "save_ready") t.set_dist_spec(mapping[(k, n)]) - delattr(t, 'save_ready') + delattr(t, "save_ready") del optim_state del mapping dist.barrier() -def load_checkpoint(path: str, - epoch: int, - model: torch.nn.Module, - optimizer: Optional[OptimizerWrapper] = None, - lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, - torch_load_kwargs: Optional[Dict] = None, - load_state_dict_kwargs: Optional[Dict] = None): +def load_checkpoint( + path: str, + epoch: int, + model: torch.nn.Module, + optimizer: Optional[OptimizerWrapper] = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + torch_load_kwargs: Optional[Dict] = None, + load_state_dict_kwargs: Optional[Dict] = None, +): """load_checkpoint load a model, whose parameters are `ColoTensor`s. Args: @@ -106,8 +110,8 @@ def load_checkpoint(path: str, gather_tensor(p) if rank == 0: - load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs) - model.load_state_dict(load_state['model'], **load_state_dict_kwargs) + load_state = torch.load(path + "/epoch_{}_model.pth".format(epoch), **torch_load_kwargs) + model.load_state_dict(load_state["model"], **load_state_dict_kwargs) dist.barrier() # scatter loaded parameters @@ -115,24 +119,24 @@ def load_checkpoint(path: str, if isinstance(p, ColoTensor): scatter_tensor(p, mapping[n]) if rank == 0: - assert hasattr(p, 'save_ready') - delattr(p, 'save_ready') + assert hasattr(p, "save_ready") + delattr(p, "save_ready") del mapping if optimizer is not None: mapping = dict() - for k, v in optimizer.state_dict()['state'].items(): + for k, v in optimizer.state_dict()["state"].items(): for n, t in v.items(): if isinstance(t, ColoTensor): mapping[(k, n)] = t.dist_spec gather_tensor(t) if rank == 0: - colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs) - optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs) + colo_checkpoint = torch.load(path + "/epoch_{}_optim.pth".format(epoch), **torch_load_kwargs) + optimizer.load_state_dict(colo_checkpoint["optim"], **load_state_dict_kwargs) dist.barrier() - for k, v in optimizer.state_dict()['state'].items(): + for k, v in optimizer.state_dict()["state"].items(): for n, t in v.items(): if isinstance(t, ColoTensor): scatter_tensor(t, mapping[(k, n)]) diff --git a/colossalai/legacy/utils/checkpoint/utils.py b/colossalai/legacy/utils/checkpoint/utils.py index c830d4811463..c56848cf06c4 100644 --- a/colossalai/legacy/utils/checkpoint/utils.py +++ b/colossalai/legacy/utils/checkpoint/utils.py @@ -8,7 +8,7 @@ def robust_broadcast(tensor): with torch.no_grad(): - is_cpu_ten = tensor.device.type == 'cpu' + is_cpu_ten = tensor.device.type == "cpu" if is_cpu_ten: b_data = tensor.cuda() else: @@ -21,8 +21,7 @@ def robust_broadcast(tensor): def gather_tensor(colo_tensor: ColoTensor) -> None: - """Make colo_tensor replicated when the rank is 0 - """ + """Make colo_tensor replicated when the rank is 0""" if not colo_tensor.is_replicate(): pg = colo_tensor.get_process_group() # for the group which contains rank 0 @@ -36,12 +35,11 @@ def gather_tensor(colo_tensor: ColoTensor) -> None: dist.barrier() if dist.get_rank() == 0: - setattr(colo_tensor, 'save_ready', True) # set saving signature + setattr(colo_tensor, "save_ready", True) # set saving signature def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: - """Reversal operation of `gather_tensor`. - """ + """Reversal operation of `gather_tensor`.""" if dist_spec.placement == DistPlacementPattern.REPLICATE: robust_broadcast(colo_tensor.data) else: @@ -57,7 +55,8 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: colo_tensor.set_dist_spec(dist_spec) else: rep_tensor = ColoTensor( - entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)) + entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec) + ) rep_tensor.set_dist_spec(dist_spec) with torch.no_grad(): colo_tensor.data.copy_(rep_tensor.data) diff --git a/colossalai/legacy/utils/checkpointing.py b/colossalai/legacy/utils/checkpointing.py index b7b29cc984d6..c068faafbf44 100644 --- a/colossalai/legacy/utils/checkpointing.py +++ b/colossalai/legacy/utils/checkpointing.py @@ -11,7 +11,7 @@ try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" from .common import is_using_pp @@ -25,10 +25,9 @@ def broadcast_state_dict(state_dict, parallel_mode): return state_dict[0] -def partition_tensor_parallel_state_dict(state_dict: OrderedDict, - parallel_mode: ParallelMode, - dims: dict = dict(), - partition_states: dict = dict()): +def partition_tensor_parallel_state_dict( + state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict() +): src_rank = gpc.get_ranks_in_group(parallel_mode)[0] depth = gpc.get_world_size(parallel_mode) group = gpc.get_cpu_group(parallel_mode) @@ -65,11 +64,11 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict, def gather_tensor_parallel_state_dict( - state_dict: OrderedDict, - parallel_mode: ParallelMode, - dims: dict = dict(), - partition_states: dict = dict(), - keep_vars: bool = False, + state_dict: OrderedDict, + parallel_mode: ParallelMode, + dims: dict = dict(), + partition_states: dict = dict(), + keep_vars: bool = False, ): dst_rank = gpc.get_ranks_in_group(parallel_mode)[0] depth = gpc.get_world_size(parallel_mode) @@ -138,8 +137,11 @@ def partition_pipeline_parallel_state_dict(model, state_dict): def gather_pipeline_parallel_state_dict(state_dict): - gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] - if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None) + gathered_states = ( + [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 + else None + ) dist.gather_object( state_dict, gathered_states, @@ -147,18 +149,23 @@ def gather_pipeline_parallel_state_dict(state_dict): group=gpc.get_cpu_group(ParallelMode.PIPELINE), ) - state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states)) - if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict()) + state_dict = ( + OrderedDict(chain.from_iterable(state.items() for state in gathered_states)) + if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 + else OrderedDict() + ) return state_dict -def save_checkpoint(file, - epoch: int, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer = None, - lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, - **kwargs): +def save_checkpoint( + file, + epoch: int, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + **kwargs, +): """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler etc. into a checkpoint dictionary. @@ -196,8 +203,11 @@ def broadcast_model(model: torch.nn.Module): src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0] for p in model.parameters(): if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0: - group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group( - ParallelMode.TENSOR) + group = ( + gpc.get_group(ParallelMode.TENSOR) + if p.device.type == "cuda" + else gpc.get_cpu_group(ParallelMode.TENSOR) + ) dist.broadcast(p, src_rank, group=group) @@ -226,8 +236,9 @@ def load_checkpoint( Raises: RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated """ - state_dict = (torch.load(file, map_location=torch.device("cpu")) - if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None) + state_dict = ( + torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None + ) # model states model_state = state_dict.pop("model") if state_dict is not None else dict() @@ -246,8 +257,11 @@ def load_checkpoint( dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL)) if gpc.get_global_rank() == 0: all_error_msgs = list(chain.from_iterable(all_error_msgs)) - raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format( - model.__class__.__name__, "\n\t".join(all_error_msgs))) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + model.__class__.__name__, "\n\t".join(all_error_msgs) + ) + ) else: raise e diff --git a/colossalai/legacy/utils/common.py b/colossalai/legacy/utils/common.py index 35095161c2f2..671bcc3d6ad7 100644 --- a/colossalai/legacy/utils/common.py +++ b/colossalai/legacy/utils/common.py @@ -80,7 +80,6 @@ def is_using_sequence(): class model_branch_context(object): - def __enter__(self): self.env_status = env.save() @@ -98,16 +97,14 @@ def _calc_l2_norm(grads): if fused_optim is None: from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() norm = 0.0 if len(grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) norm, _ = multi_tensor_applier( - fused_optim.multi_tensor_l2norm, - dummy_overflow_buf, - [grads], - False # no per-parameter norm + fused_optim.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm ) return norm @@ -121,7 +118,7 @@ def _calc_lp(grads, norm_type): def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: - if torch.is_tensor(norm) and norm.device.type != 'cuda': + if torch.is_tensor(norm) and norm.device.type != "cuda": norm = norm.to(torch.cuda.current_device()) return norm @@ -141,11 +138,11 @@ def _compute_local_lp(params: List[ColoParameter], norm_type: float) -> float: if len(params) == 0: return 0.0 grads = [p.grad for p in params] - use_cuda_kernel = grads[0].device.type == 'cuda' + use_cuda_kernel = grads[0].device.type == "cuda" if norm_type == inf: local_lp = max([g.abs().max() for g in grads]) elif norm_type == 2.0 and use_cuda_kernel: - local_lp = _calc_l2_norm(grads)**norm_type + local_lp = _calc_l2_norm(grads) ** norm_type else: local_lp = _calc_lp(grads, norm_type) if isinstance(local_lp, torch.Tensor): @@ -202,8 +199,8 @@ def _compute_grad_lp(parameters, norm_type: float = 2.0) -> float: assert isinstance(p, ColoParameter) if grad_dtype is None: grad_dtype = p.grad.dtype - assert p.grad.dtype == grad_dtype, f'Expected all grads are {grad_dtype}, got {p.grad.dtype}' - if p.grad.device.type == 'cuda': + assert p.grad.dtype == grad_dtype, f"Expected all grads are {grad_dtype}, got {p.grad.dtype}" + if p.grad.device.type == "cuda": cuda_grad_params.append(p) else: cpu_grad_params.append(p) @@ -221,7 +218,7 @@ def compute_grad_norm(parameters, norm_type: float = 2.0) -> float: norm_type = float(norm_type) total_norm = _compute_grad_lp(parameters, norm_type) if norm_type != inf: - total_norm = total_norm**(1 / norm_type) + total_norm = total_norm ** (1 / norm_type) return total_norm @@ -235,14 +232,15 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None: for p in parameters: if p.grad is None: continue - if p.grad.device.type == 'cuda': + if p.grad.device.type == "cuda": cuda_grads.append(p.grad.detach()) else: cpu_grads.append(p.grad.detach()) if len(cuda_grads) > 0: dummy_overflow_buf = torch.cuda.IntTensor([0]) - multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], - clip_coef) + multi_tensor_applier( + fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef + ) for g in cpu_grads: g.mul_(clip_coef) @@ -284,16 +282,17 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): for param in parameters: if param.grad is not None: # Make sure the grads are in fp32 - assert param.grad.dtype == torch.float, \ - f'expected gradient to be dtype torch.float, but got {param.grad.type()}' - if hasattr(param, 'colo_attr') and param.colo_attr.sharded_data_tensor.is_sharded: + assert ( + param.grad.dtype == torch.float + ), f"expected gradient to be dtype torch.float, but got {param.grad.type()}" + if hasattr(param, "colo_attr") and param.colo_attr.sharded_data_tensor.is_sharded: has_zero_shared_param = True params.append(param) if len(params) == 0: enable_cuda_kernels = False else: - enable_cuda_kernels = params[0].grad.device.type == 'cuda' + enable_cuda_kernels = params[0].grad.device.type == "cuda" # Norm parameters. max_norm = float(max_norm) norm_type = float(norm_type) @@ -307,15 +306,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) # Take max across all model-parallel GPUs. if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: - dist.all_reduce(total_norm_cuda, - op=dist.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.MODEL), - async_op=False) + dist.all_reduce( + total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL), async_op=False + ) if has_zero_shared_param: - dist.all_reduce(total_norm_cuda, - op=dist.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.DATA), - async_op=False) + dist.all_reduce( + total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.DATA), async_op=False + ) total_norm = total_norm_cuda[0].item() else: tensor_parallel_grads = [] @@ -323,17 +320,17 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): zero_sharded_grads = [] for p in params: if is_model_parallel_parameter(p): - reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type) + reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS)) ** (1 / norm_type) tensor_parallel_grads.append(p.grad.data / reductor) - elif hasattr(p, 'colo_attr') and p.colo_attr.sharded_data_tensor.is_sharded: + elif hasattr(p, "colo_attr") and p.colo_attr.sharded_data_tensor.is_sharded: zero_sharded_grads.append(p.grad.data) else: no_tensor_parallel_grads.append(p.grad.data) if norm_type == 2.0 and enable_cuda_kernels: - tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type - no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type - zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type + tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads) ** norm_type + no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads) ** norm_type + zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type else: tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) @@ -358,7 +355,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): total_norm = tensor_parallel_norm + no_tensor_parallel_norm if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE)) - total_norm = total_norm**(1.0 / norm_type) + total_norm = total_norm ** (1.0 / norm_type) if torch.is_tensor(total_norm): total_norm = total_norm.item() @@ -397,13 +394,14 @@ def count_zeros_fp32(parameters): # Sum across all model-parallel GPUs. ops = [] ops.append( - dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)) + dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True) + ) if gpc.is_initialized(ParallelMode.PIPELINE): ops.append( - dist.all_reduce(total_num_zeros, - op=dist.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PIPELINE), - async_op=True)) + dist.all_reduce( + total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE), async_op=True + ) + ) for req in ops: req.wait() @@ -420,8 +418,9 @@ def copy_tensor_parallel_attributes(src_tensor, dst_tensor): def param_is_not_tensor_parallel_duplicate(param): - return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or (gpc.get_local_rank( - ParallelMode.TENSOR) == 0) + return (hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL)) or ( + gpc.get_local_rank(ParallelMode.TENSOR) == 0 + ) @contextmanager diff --git a/colossalai/legacy/utils/data_sampler/__init__.py b/colossalai/legacy/utils/data_sampler/__init__.py index 12798a94c2d0..677d767667f2 100644 --- a/colossalai/legacy/utils/data_sampler/__init__.py +++ b/colossalai/legacy/utils/data_sampler/__init__.py @@ -1,4 +1,4 @@ from .base_sampler import BaseSampler from .data_parallel_sampler import DataParallelSampler, get_dataloader -__all__ = ['BaseSampler', 'DataParallelSampler', 'get_dataloader'] +__all__ = ["BaseSampler", "DataParallelSampler", "get_dataloader"] diff --git a/colossalai/legacy/utils/data_sampler/base_sampler.py b/colossalai/legacy/utils/data_sampler/base_sampler.py index 89f3bca5b1b5..c6b916fc4870 100644 --- a/colossalai/legacy/utils/data_sampler/base_sampler.py +++ b/colossalai/legacy/utils/data_sampler/base_sampler.py @@ -5,7 +5,6 @@ class BaseSampler(ABC): - def __init__(self, dataset, batch_size): self.dataset = dataset self.batch_size = batch_size diff --git a/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py b/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py index 66a5fdd3694d..41d0861e2249 100644 --- a/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py +++ b/colossalai/legacy/utils/data_sampler/data_parallel_sampler.py @@ -13,7 +13,7 @@ from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -T_co = TypeVar('T_co', covariant=True) +T_co = TypeVar("T_co", covariant=True) class DataParallelSampler(Sampler): @@ -44,11 +44,11 @@ def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_ self.num_samples = math.ceil( # `type:ignore` is required because Dataset cannot provide a default __len__ # see NOTE in pytorch/torch/utils/data/sampler.py - (len(self.dataset) - self.num_replicas) / \ - self.num_replicas # type: ignore[arg-type] + (len(self.dataset) - self.num_replicas) + / self.num_replicas # type: ignore[arg-type] ) else: - self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed @@ -65,7 +65,7 @@ def __iter__(self) -> Iterator[T_co]: # set_epoch manually self.epoch += 1 else: - indices = list(range(len(self.dataset))) # type: ignore[arg-type] + indices = list(range(len(self.dataset))) # type: ignore[arg-type] if not self.drop_last: # add extra samples to make it evenly divisible @@ -76,11 +76,11 @@ def __iter__(self) -> Iterator[T_co]: indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. - indices = indices[:self.total_size] + indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) @@ -99,14 +99,9 @@ def set_epoch(self, epoch: int) -> None: self.epoch = epoch -def get_dataloader(dataset, - shuffle=False, - seed=1024, - add_sampler=True, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): +def get_dataloader( + dataset, shuffle=False, seed=1024, add_sampler=True, drop_last=False, pin_memory=False, num_workers=0, **kwargs +): r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not) Note: @@ -144,18 +139,22 @@ def seed_worker(worker_id): random.seed(worker_seed) if sampler is None: - return DataLoader(dataset, - worker_init_fn=seed_worker, - shuffle=shuffle, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + worker_init_fn=seed_worker, + shuffle=shuffle, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) else: - return DataLoader(dataset, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) diff --git a/colossalai/legacy/utils/memory.py b/colossalai/legacy/utils/memory.py index 360bf0da4a77..2f99a7d2f72e 100644 --- a/colossalai/legacy/utils/memory.py +++ b/colossalai/legacy/utils/memory.py @@ -76,8 +76,10 @@ def report_memory_usage(message, logger=None, report_cpu=False): gpu_cached = _bytes_to_MB(torch.cuda.memory_reserved()) gpu_max_cached = _bytes_to_MB(torch.cuda.max_memory_reserved()) - full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \ + full_log = ( + f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " + f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB" + ) if report_cpu: # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports @@ -91,7 +93,7 @@ def report_memory_usage(message, logger=None, report_cpu=False): logger.info(full_log) # get the peak memory to report correct data, so reset the counter for the next call - if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ torch.cuda.reset_peak_memory_stats() @@ -106,10 +108,10 @@ def colo_device_memory_capacity(device: torch.device) -> int: int: size in byte """ assert isinstance(device, torch.device) - if device.type == 'cpu': + if device.type == "cpu": # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory. return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node - if device.type == 'cuda': + if device.type == "cuda": return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION @@ -123,16 +125,16 @@ def colo_device_memory_used(device: torch.device) -> int: Returns: int: memory size in bytes """ - if device.type == 'cpu': + if device.type == "cpu": mem_info = _get_cpu_memory_info() # In the context of 1-CPU-N-GPU, the memory usage of the current process is 1/N CPU memory used. # Each process consumes the same amount of memory. ret = mem_info.used / gpc.num_processes_on_current_node return ret - elif device.type == 'cuda': + elif device.type == "cuda": ret: int = torch.cuda.memory_allocated(device) # get the peak memory to report correct data, so reset the counter for the next call - if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ + if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ torch.cuda.reset_peak_memory_stats(device) return ret @@ -145,9 +147,9 @@ def colo_set_process_memory_fraction(ratio: float) -> None: Args: ratio (float): a ratio between 0. ~ 1. """ - if version.parse(torch.__version__) < version.parse('1.8'): - logger = get_dist_logger('colo_set_process_memory_fraction') - logger.warning('colo_set_process_memory_fraction failed because torch version is less than 1.8') + if version.parse(torch.__version__) < version.parse("1.8"): + logger = get_dist_logger("colo_set_process_memory_fraction") + logger.warning("colo_set_process_memory_fraction failed because torch version is less than 1.8") return global _GLOBAL_CUDA_MEM_FRACTION _GLOBAL_CUDA_MEM_FRACTION = ratio diff --git a/colossalai/legacy/utils/profiler/extention.py b/colossalai/legacy/utils/profiler/extention.py index 6726a683cc05..c07c6046bb1c 100644 --- a/colossalai/legacy/utils/profiler/extention.py +++ b/colossalai/legacy/utils/profiler/extention.py @@ -2,7 +2,6 @@ class ProfilerExtension(ABC): - @abstractmethod def prepare_trace(self): pass diff --git a/colossalai/legacy/utils/profiler/legacy/__init__.py b/colossalai/legacy/utils/profiler/legacy/__init__.py index 88beed86d7de..88b4201d8bf3 100644 --- a/colossalai/legacy/utils/profiler/legacy/__init__.py +++ b/colossalai/legacy/utils/profiler/legacy/__init__.py @@ -3,4 +3,4 @@ from .pcie_profiler import PcieProfiler from .prof_utils import BaseProfiler, ProfilerContext -__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext'] +__all__ = ["BaseProfiler", "CommProfiler", "PcieProfiler", "MemProfiler", "ProfilerContext"] diff --git a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py index bb7e2654c740..ad54b989f412 100644 --- a/colossalai/legacy/utils/profiler/legacy/comm_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/comm_profiler.py @@ -20,14 +20,14 @@ def _get_code_location(depth: int): upper_frame = inspect.stack()[i] function_name = inspect.stack()[i - 1].function ret.append(upper_frame.filename) - ret.append('(') + ret.append("(") ret.append(str(upper_frame.lineno)) - ret.append('): ') + ret.append("): ") ret.append(function_name) if i != length - 1: - ret.append('\n') + ret.append("\n") - return ''.join(ret) + return "".join(ret) torch_all_reduce = dist.all_reduce @@ -42,7 +42,7 @@ class CommEvent(object): volume recording. """ - def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0): + def __init__(self, count: int = 0, comm_vol: float = 0.0, cuda_time: int = 0): self.self_count = count self.self_comm_vol = comm_vol self.self_cuda_time = cuda_time @@ -54,8 +54,7 @@ def add(self, rhs): class CommProfiler(BaseProfiler): - """Communication profiler. Records all communication events. - """ + """Communication profiler. Records all communication events.""" def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): super().__init__(profiler_name="Collective_Communication", priority=0) @@ -114,8 +113,10 @@ def append(s: str = None): res.append(sep) if self.warn_flag: - append("Warning: there exists multiple communication operations in the same time. As a result, " - "the profiling result is not accurate.") + append( + "Warning: there exists multiple communication operations in the same time. As a result, " + "the profiling result is not accurate." + ) if self.total_cuda_time == 0: return "No collective communication has been called yet!" @@ -126,24 +127,29 @@ def append(s: str = None): append("total number of calls: {}".format(self.total_count)) append("All events:") - separation = '-' * 74 - row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2 + separation = "-" * 74 + row_format = "{:^10}" + "{:^12}" * 2 + "{:^16}" + "{:^12}" * 2 append(separation) - append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls')) + append(row_format.format("Location", "GPU time", "Percentage", "Comm volume", "Bandwidth", "Num of calls")) append(separation) show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) for location, event in show_list: append(location) append( - row_format.format('', _format_time(event.self_cuda_time), - '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0), - _format_memory(event.self_comm_vol), - _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count)) + row_format.format( + "", + _format_time(event.self_cuda_time), + "{:.1f}%".format(event.self_cuda_time / self.total_cuda_time * 100.0), + _format_memory(event.self_comm_vol), + _format_bandwidth(event.self_comm_vol, event.self_cuda_time), + event.self_count, + ) + ) append() - return ''.join(res) + return "".join(res) @property def has_aync_op(self): @@ -195,8 +201,7 @@ def wait_async_op(self): class CommHandler(object): - """Communication handler. A dummy handler to wait aync operations. - """ + """Communication handler. A dummy handler to wait aync operations.""" def __init__(self, profiler: CommProfiler): super().__init__() @@ -212,11 +217,9 @@ def async_check(profiler: CommProfiler): profiler.wait_async_op() -def all_reduce(tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: +def all_reduce( + tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, group=None, async_op: bool = False, profiler: CommProfiler = None +) -> Optional[CommHandler]: async_check(profiler) comm_size = dist.get_world_size(group) @@ -231,12 +234,14 @@ def all_reduce(tensor: torch.Tensor, profiler.close_profiler(group) -def reduce_scatter(output: torch.Tensor, - input_list: List[torch.Tensor], - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: +def reduce_scatter( + output: torch.Tensor, + input_list: List[torch.Tensor], + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None, +) -> Optional[CommHandler]: async_check(profiler) comm_size = dist.get_world_size(group) @@ -254,11 +259,13 @@ def reduce_scatter(output: torch.Tensor, profiler.close_profiler(group) -def all_gather(tensor_list: List[torch.Tensor], - tensor: torch.Tensor, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: +def all_gather( + tensor_list: List[torch.Tensor], + tensor: torch.Tensor, + group=None, + async_op: bool = False, + profiler: CommProfiler = None, +) -> Optional[CommHandler]: async_check(profiler) comm_size = dist.get_world_size(group) @@ -276,11 +283,9 @@ def all_gather(tensor_list: List[torch.Tensor], profiler.close_profiler(group) -def broadcast(tensor: torch.Tensor, - src: int, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: +def broadcast( + tensor: torch.Tensor, src: int, group=None, async_op: bool = False, profiler: CommProfiler = None +) -> Optional[CommHandler]: async_check(profiler) comm_vol = 1.0 * tensor.element_size() * tensor.numel() @@ -293,12 +298,14 @@ def broadcast(tensor: torch.Tensor, profiler.close_profiler(group) -def reduce(tensor: torch.Tensor, - dst: int, - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: +def reduce( + tensor: torch.Tensor, + dst: int, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None, +) -> Optional[CommHandler]: async_check(profiler) comm_vol = 1.0 * tensor.element_size() * tensor.numel() diff --git a/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py b/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py index 514d3c6fabfa..10a3f8dfc43b 100644 --- a/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py +++ b/colossalai/legacy/utils/profiler/legacy/pcie_profiler.py @@ -18,6 +18,7 @@ def _get_size(dtype: str): def _get_numel(my_list: List[int]) -> int: from functools import reduce from operator import mul + return reduce(mul, my_list) @@ -27,12 +28,11 @@ def _reduce_location(locations: List[str]) -> str: ret.append(lo) ret.append("\n") ret = ret[:-1] - return ''.join(ret) + return "".join(ret) class PcieEvent(object): - """Pcie Event. - """ + """Pcie Event.""" def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0): self.count = count @@ -73,12 +73,9 @@ def reset(self): self.profiler = None def enable(self): - self.profiler = profile(enabled=True, - use_cuda=True, - use_cpu=True, - use_kineto=True, - record_shapes=True, - with_stack=True) + self.profiler = profile( + enabled=True, use_cuda=True, use_cpu=True, use_kineto=True, record_shapes=True, with_stack=True + ) self.profiler.__enter__() def disable(self): @@ -92,15 +89,15 @@ def disable(self): if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0: continue current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) - code_location = _reduce_location(event.stack[:self.depth]) + code_location = _reduce_location(event.stack[: self.depth]) if code_location in self.ops_record: self.ops_record[code_location].add(current_comm_event) else: self.ops_record[code_location] = current_comm_event - elif 'Memcpy HtoD' in event.name: + elif "Memcpy HtoD" in event.name: self.h2d_count += 1 self.h2d_time += event.cuda_time_total - elif 'Memcpy DtoH' in event.name: + elif "Memcpy DtoH" in event.name: self.d2h_count += 1 self.d2h_time += event.cuda_time_total @@ -132,19 +129,25 @@ def append(s: str = None): append("Possible data transmission events in PCIE:") - separation = '-' * 62 - row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2 + separation = "-" * 62 + row_format = "{:^10}" + "{:^12}" + "{:^16}" + "{:^12}" * 2 append(separation) - append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls')) + append(row_format.format("Location", "GPU time", "Trans volume", "Bandwidth", "Num of calls")) append(separation) show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) for location, event in show_list: append(location) append( - row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol), - _format_bandwidth(event.pcie_vol, event.cuda_time), event.count)) + row_format.format( + "", + _format_time(event.cuda_time), + _format_memory(event.pcie_vol), + _format_bandwidth(event.pcie_vol, event.cuda_time), + event.count, + ) + ) append() - return ''.join(res) + return "".join(res) diff --git a/colossalai/legacy/utils/profiler/legacy/prof_utils.py b/colossalai/legacy/utils/profiler/legacy/prof_utils.py index 9b948c9ec1cd..95eecf0715e7 100644 --- a/colossalai/legacy/utils/profiler/legacy/prof_utils.py +++ b/colossalai/legacy/utils/profiler/legacy/prof_utils.py @@ -11,10 +11,10 @@ def _format_time(time_us): US_IN_SECOND = 1000.0 * 1000.0 US_IN_MS = 1000.0 if time_us >= US_IN_SECOND: - return '{:.3f}s'.format(time_us / US_IN_SECOND) + return "{:.3f}s".format(time_us / US_IN_SECOND) if time_us >= US_IN_MS: - return '{:.3f}ms'.format(time_us / US_IN_MS) - return '{:.3f}us'.format(time_us) + return "{:.3f}ms".format(time_us / US_IN_MS) + return "{:.3f}us".format(time_us) # copied from high version pytorch to support low version @@ -23,28 +23,27 @@ def _format_memory(nbytes): KB = 1024 MB = 1024 * KB GB = 1024 * MB - if (abs(nbytes) >= GB): - return '{:.2f} GB'.format(nbytes * 1.0 / GB) - elif (abs(nbytes) >= MB): - return '{:.2f} MB'.format(nbytes * 1.0 / MB) - elif (abs(nbytes) >= KB): - return '{:.2f} KB'.format(nbytes * 1.0 / KB) + if abs(nbytes) >= GB: + return "{:.2f} GB".format(nbytes * 1.0 / GB) + elif abs(nbytes) >= MB: + return "{:.2f} MB".format(nbytes * 1.0 / MB) + elif abs(nbytes) >= KB: + return "{:.2f} KB".format(nbytes * 1.0 / KB) else: - return str(nbytes) + ' B' + return str(nbytes) + " B" def _format_bandwidth(volume: float or int, time_us: int): - sec_div_mb = (1000.0 / 1024.0)**2 + sec_div_mb = (1000.0 / 1024.0) ** 2 mb_per_sec = volume / time_us * sec_div_mb if mb_per_sec >= 1024.0: - return '{:.3f} GB/s'.format(mb_per_sec / 1024.0) + return "{:.3f} GB/s".format(mb_per_sec / 1024.0) else: - return '{:.3f} MB/s'.format(mb_per_sec) + return "{:.3f} MB/s".format(mb_per_sec) class BaseProfiler(ABC): - def __init__(self, profiler_name: str, priority: int): self.name = profiler_name self.priority = priority @@ -111,8 +110,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): def to_tensorboard(self, writer): from torch.utils.tensorboard import SummaryWriter - assert isinstance(writer, SummaryWriter), \ - f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.' + assert isinstance( + writer, SummaryWriter + ), f"torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}." for prof in self.profilers: prof.to_tensorboard(writer) @@ -124,7 +124,7 @@ def to_file(self, log_dir: Union[str, Path]): if not log_dir.exists(): log_dir.mkdir(parents=True, exist_ok=True) for prof in self.profilers: - log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log') + log_file = log_dir.joinpath(f"{prof.name}_rank_{gpc.get_global_rank()}.log") prof.to_file(log_file) def show(self): diff --git a/colossalai/legacy/utils/profiler/profiler.py b/colossalai/legacy/utils/profiler/profiler.py index 0827f06b586c..b7a75f25d951 100644 --- a/colossalai/legacy/utils/profiler/profiler.py +++ b/colossalai/legacy/utils/profiler/profiler.py @@ -120,26 +120,30 @@ def trace_handler(prof): p.step() """ - def __init__(self, - *, - activities: Optional[Iterable[ProfilerActivity]] = None, - schedule: Optional[Callable[[int], ProfilerAction]] = None, - on_trace_ready: Optional[Callable[..., Any]] = None, - engine: Optional[Engine] = None, - record_shapes: bool = False, - profile_memory: bool = False, - with_stack: bool = False, - with_flops: bool = False, - with_modules: bool = False, - profile_stateful_tensor_memory: bool = False) -> None: - super().__init__(activities=activities, - schedule=schedule, - on_trace_ready=on_trace_ready, - record_shapes=record_shapes, - profile_memory=profile_memory, - with_stack=with_stack, - with_flops=with_flops, - with_modules=with_modules) + def __init__( + self, + *, + activities: Optional[Iterable[ProfilerActivity]] = None, + schedule: Optional[Callable[[int], ProfilerAction]] = None, + on_trace_ready: Optional[Callable[..., Any]] = None, + engine: Optional[Engine] = None, + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + profile_stateful_tensor_memory: bool = False, + ) -> None: + super().__init__( + activities=activities, + schedule=schedule, + on_trace_ready=on_trace_ready, + record_shapes=record_shapes, + profile_memory=profile_memory, + with_stack=with_stack, + with_flops=with_flops, + with_modules=with_modules, + ) self._logger = get_dist_logger() self.extentions: List[ProfilerExtension] = [] if profile_stateful_tensor_memory: @@ -149,9 +153,9 @@ def __init__(self, self.extentions.append(StatefulTensorMemoryProfilerExtention(engine)) def prepare_trace(self) -> None: - if hasattr(super(), 'prepare_trace'): + if hasattr(super(), "prepare_trace"): super().prepare_trace() - elif hasattr(super(), '_start_warmup'): + elif hasattr(super(), "_start_warmup"): super()._start_warmup() for ext in self.extentions: ext.prepare_trace() @@ -160,9 +164,9 @@ def _start_warmup(self): self.prepare_trace() def start_trace(self): - if hasattr(super(), '_start_trace'): + if hasattr(super(), "_start_trace"): super()._start_trace() - elif hasattr(super(), 'start_trace'): + elif hasattr(super(), "start_trace"): super().start_trace() for ext in self.extentions: ext.start_trace() @@ -171,9 +175,9 @@ def _start_trace(self): self.start_trace() def stop_trace(self): - if hasattr(super(), '_stop_trace'): + if hasattr(super(), "_stop_trace"): super()._stop_trace() - elif hasattr(super(), 'stop_trace'): + elif hasattr(super(), "stop_trace"): super().stop_trace() for ext in self.extentions: ext.stop_trace() @@ -186,15 +190,15 @@ def export_chrome_trace(self, path: str): Exports the collected trace in Chrome JSON format. """ assert self.profiler - fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False) + fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) fp.close() retvalue = self.profiler.export_chrome_trace(fp.name) with open(fp.name) as fin: trace = json.load(fin) for ext in self.extentions: trace = ext.extend_chrome_trace(trace) - open_func = gzip.open if path.endswith('.gz') else open - with open_func(path, 'wt') as fout: + open_func = gzip.open if path.endswith(".gz") else open + with open_func(path, "wt") as fout: json.dump(trace, fout) os.remove(fp.name) diff --git a/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py index f3bb66ced583..9247a9b80772 100644 --- a/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py +++ b/colossalai/legacy/utils/profiler/stateful_tensor_mem_extention.py @@ -22,11 +22,11 @@ def get_timestamp_us(): def generic_instant_event(name, pid, tid, timestamp, args): - return {'ph': 'i', 's': 't', 'name': name, 'pid': pid, 'tid': tid, 'ts': timestamp, 'args': args} + return {"ph": "i", "s": "t", "name": name, "pid": pid, "tid": tid, "ts": timestamp, "args": args} class StatefulTensorMemoryEvent: - EVENT_NAME = '[statefulTensorMemory]' + EVENT_NAME = "[statefulTensorMemory]" def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None: self.pid = os.getpid() @@ -37,22 +37,23 @@ def __init__(self, timestamp: int, device_type: DeviceType, bytes_: int) -> None self.bytes = bytes_ def state_dict(self): - return generic_instant_event(StatefulTensorMemoryEvent.EVENT_NAME, self.pid, self.tid, self.timestamp, { - 'Device Type': self.device_type.value, - 'Device Id': self.device_id, - 'Bytes': self.bytes - }) + return generic_instant_event( + StatefulTensorMemoryEvent.EVENT_NAME, + self.pid, + self.tid, + self.timestamp, + {"Device Type": self.device_type.value, "Device Id": self.device_id, "Bytes": self.bytes}, + ) class StatefulTensorMemoryTracer: - def __init__(self) -> None: self.events: List[StatefulTensorMemoryEvent] = [] self._tracing = False def sample(self): - cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] - cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu'] + cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"] + cpu_mem = StatefulTensor.GST_MGR.total_mem["cpu"] timestamp = get_timestamp_us() if self._tracing: self.events.append(StatefulTensorMemoryEvent(timestamp, DeviceType.CUDA, cuda_mem)) @@ -70,7 +71,6 @@ def state_dict(self): class StatefulTensorMemoryTracerHook(BaseOpHook): - def __init__(self, tracer: StatefulTensorMemoryTracer): super().__init__() self.tracer = tracer @@ -104,7 +104,6 @@ def disable(self): class StatefulTensorMemoryProfilerExtention(ProfilerExtension): - def __init__(self, engine: Engine) -> None: self.engine = engine self.tracer = StatefulTensorMemoryTracer() @@ -131,5 +130,5 @@ def stop_trace(self): # self.hook_registered = False def extend_chrome_trace(self, trace: dict) -> dict: - trace['traceEvents'].extend(self.tracer.state_dict()) + trace["traceEvents"].extend(self.tracer.state_dict()) return trace diff --git a/colossalai/legacy/zero/__init__.py b/colossalai/legacy/zero/__init__.py index 3783d38e61b2..760fd529f3a6 100644 --- a/colossalai/legacy/zero/__init__.py +++ b/colossalai/legacy/zero/__init__.py @@ -11,8 +11,9 @@ from .sharded_optim import ShardedOptimizerV2 -def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, - optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: +def convert_to_zero_v2( + model: nn.Module, optimizer: torch.optim.Optimizer, model_config, optimizer_config +) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: """ A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading @@ -25,12 +26,12 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model :rtype: Tuple """ - logger = get_dist_logger('convert_to_zero_v2') + logger = get_dist_logger("convert_to_zero_v2") - logger.info(f'optimizer_config is {optimizer_config}', ranks=[0]) + logger.info(f"optimizer_config is {optimizer_config}", ranks=[0]) if optimizer_config is None: optimizer_config = dict() - logger.info(f'model_config is {model_config}', ranks=[0]) + logger.info(f"model_config is {model_config}", ranks=[0]) if model_config is None: model_config = dict() @@ -40,6 +41,12 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model __all__ = [ - 'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context', - 'no_shard_zero_decrator', 'TensorShardStrategy', 'BucketTensorShardStrategy' + "convert_to_zero_v2", + "ShardedModelV2", + "ShardedOptimizerV2", + "ZeroInitContext", + "no_shard_zero_context", + "no_shard_zero_decrator", + "TensorShardStrategy", + "BucketTensorShardStrategy", ] diff --git a/colossalai/legacy/zero/gemini/__init__.py b/colossalai/legacy/zero/gemini/__init__.py index 754ae9bc0044..b272980d34d8 100644 --- a/colossalai/legacy/zero/gemini/__init__.py +++ b/colossalai/legacy/zero/gemini/__init__.py @@ -4,6 +4,11 @@ from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy __all__ = [ - 'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy', - 'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook' + "StatefulTensorMgr", + "StatefulTensor", + "CPUTensorPlacementPolicy", + "CUDATensorPlacementPolicy", + "AutoTensorPlacementPolicy", + "register_ophooks_recursively", + "BaseOpHook", ] diff --git a/colossalai/legacy/zero/gemini/gemini_context.py b/colossalai/legacy/zero/gemini/gemini_context.py index 9a7da6b80fba..9e82d948fba7 100644 --- a/colossalai/legacy/zero/gemini/gemini_context.py +++ b/colossalai/legacy/zero/gemini/gemini_context.py @@ -2,16 +2,15 @@ class GeminiMemoryManager(object): - def __init__(self, states_cls: EnumMeta): super().__init__() self.states_cls = states_cls - self._cnter = 0 # the counter of instances + self._cnter = 0 # the counter of instances self.total_mem = dict() self.state_mem = dict() - self.state_mem['cpu'] = dict() - self.state_mem['cuda'] = dict() + self.state_mem["cpu"] = dict() + self.state_mem["cuda"] = dict() self.reset() @@ -20,15 +19,15 @@ def total_number(self): return self._cnter def reset(self): - self._cnter = 0 # the counter of instances + self._cnter = 0 # the counter of instances - self.total_mem['cpu'] = 0 # memory occupation of instances in cpu - self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda + self.total_mem["cpu"] = 0 # memory occupation of instances in cpu + self.total_mem["cuda"] = 0 # memory of occupation of instances in cuda # memory conditions for all states for state in self.states_cls: - self.state_mem['cpu'][state] = 0 - self.state_mem['cuda'][state] = 0 + self.state_mem["cpu"][state] = 0 + self.state_mem["cuda"][state] = 0 def register_new_instance(self): self._cnter += 1 @@ -37,12 +36,16 @@ def delete_instance(self): self._cnter -= 1 def print_info(self): - print(f"Total number: {self.total_number}", - f"Total CPU memory occupation: {self.total_mem['cpu']}", - f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", - sep='\n') + print( + f"Total number: {self.total_number}", + f"Total CPU memory occupation: {self.total_mem['cpu']}", + f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", + sep="\n", + ) for state in self.states_cls: - print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", - f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", - sep='\n') + print( + f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", + f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", + sep="\n", + ) diff --git a/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py b/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py index d68a9dc6458f..4129b14bcae9 100644 --- a/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py +++ b/colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py @@ -22,7 +22,7 @@ def post_fwd_exec(self, module: torch.nn.Module, *args): def pre_bwd_exec(self, module: torch.nn.Module, input, output): for param in module.parameters(): - assert hasattr(param, '_sharded_grad') + assert hasattr(param, "_sharded_grad") param._sharded_grad.setup() def post_bwd_exec(self, module: torch.nn.Module, input): diff --git a/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py b/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py index 6b76a2116a49..e0c83eec0445 100644 --- a/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py +++ b/colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py @@ -19,25 +19,25 @@ def niter(self): def pre_fwd_exec(self, module: torch.nn.Module, *args): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.gather() param.data = param.ca_attr.payload() def post_fwd_exec(self, module: torch.nn.Module, *args): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.shard() param.data = param.ca_attr.payload() def pre_bwd_exec(self, module: torch.nn.Module, input, output): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.gather() param.data = param.ca_attr.payload() def post_bwd_exec(self, module: torch.nn.Module, input): for param in module.parameters(): - assert hasattr(param, 'ca_attr') + assert hasattr(param, "ca_attr") param.ca_attr.shard() param.data = param.ca_attr.payload() diff --git a/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py index eebcf86e0e58..57076063cb3f 100644 --- a/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py +++ b/colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py @@ -15,8 +15,7 @@ class TrainingPhase(Enum): BACKWARD = 1 -class GradMemStats(): - +class GradMemStats: def __init__(self) -> None: self.unreleased_grad_flag = {} self.unreleased_grad_volume = 0 @@ -26,8 +25,7 @@ def clear(self): self.unreleased_grad_volume = 0 -class GradMemTracerHook(): - +class GradMemTracerHook: def __init__(self, grad_stats: GradMemStats): self.grad_hook_list = [] self._grad_stats = grad_stats @@ -50,7 +48,6 @@ def remove_grad_hook(self): class ParamMemTracerHook(ColoParamOpHook): - def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None: super().__init__() self._training_phase = TrainingPhase.FORWARD @@ -79,10 +76,9 @@ def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]): if cur_dev == "cpu": if p.grad is not None and p.grad.device.type == "cpu": raise NotImplementedError("Only run in forward propagation") - p.data = torch.empty(p.data.shape, - device="cuda", - dtype=p.data.dtype, - requires_grad=p.data.requires_grad) + p.data = torch.empty( + p.data.shape, device="cuda", dtype=p.data.dtype, requires_grad=p.data.requires_grad + ) elif cur_dev == "cuda": alloc_storage(p.data) diff --git a/colossalai/legacy/zero/gemini/ophooks/utils.py b/colossalai/legacy/zero/gemini/ophooks/utils.py index f88ad2b00e9e..057906156d8d 100644 --- a/colossalai/legacy/zero/gemini/ophooks/utils.py +++ b/colossalai/legacy/zero/gemini/ophooks/utils.py @@ -48,7 +48,6 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs): class PreBackwardFunction(torch.autograd.Function): - @staticmethod def forward(ctx, module, pre_backward_function, outputs): ctx.module = module @@ -64,7 +63,6 @@ def backward(ctx, *args): class PostBackwardFunction(torch.autograd.Function): - @staticmethod def forward(ctx, module, pre_backward_function, output): ctx.module = module @@ -84,16 +82,15 @@ def backward(ctx, *args): return (None, None) + args -def register_ophooks_recursively(module: torch.nn.Module, - ophook_list: List[BaseOpHook], - name: str = "", - filter_fn: Optional[Callable] = None): +def register_ophooks_recursively( + module: torch.nn.Module, ophook_list: List[BaseOpHook], name: str = "", filter_fn: Optional[Callable] = None +): r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD.""" assert isinstance(module, torch.nn.Module) assert isinstance(ophook_list, (list, tuple)) - assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0' + assert len(ophook_list) > 0, "expected at least 1 hook in the argument ophook_list but found 0" for hook in ophook_list: - assert (isinstance(hook, BaseOpHook)) + assert isinstance(hook, BaseOpHook) # Add hooks for submodules for child_name, child in module.named_children(): @@ -118,7 +115,6 @@ def _post_forward_module_hook(submodule, *args): hook.post_fwd_exec(submodule, *args) def _pre_backward_module_hook(submodule, inputs, output): - def _run_before_backward_function(submodule): for hook in ophook_list: assert isinstance(submodule, torch.nn.Module) @@ -127,7 +123,6 @@ def _run_before_backward_function(submodule): return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output) def _post_backward_module_hook(submodule, inputs): - def _run_after_backward_function(submodule): for hook in ophook_list: assert isinstance(submodule, torch.nn.Module) diff --git a/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py b/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py index 84f32be358e3..91c7bdc2961b 100644 --- a/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py +++ b/colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py @@ -5,7 +5,6 @@ class BaseParamHookMgr(object): - def __init__(self, param_list: List[torch.nn.Parameter]) -> None: r""" register backward hook on every parameters of module @@ -23,9 +22,9 @@ def register_backward_hooks(self, hook_call: Callable) -> None: ``` """ if not torch.is_grad_enabled(): - return # don't register grad hooks if grad isn't enabled + return # don't register grad hooks if grad isn't enabled for p in self._param_list: - if p.requires_grad and not hasattr(p, '_base_param_hook'): + if p.requires_grad and not hasattr(p, "_base_param_hook"): handle = p.register_hook(functools.partial(hook_call, p)) p._base_param_hook = handle @@ -35,5 +34,5 @@ def remove_hooks(self) -> None: """ for p in self._param_list: - if p.requires_grad and hasattr(p, '_base_param_hook'): + if p.requires_grad and hasattr(p, "_base_param_hook"): p._base_param_hook.remove() diff --git a/colossalai/legacy/zero/gemini/stateful_tensor.py b/colossalai/legacy/zero/gemini/stateful_tensor.py index 1619ae40798d..668d344132d0 100644 --- a/colossalai/legacy/zero/gemini/stateful_tensor.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor.py @@ -25,13 +25,14 @@ class StatefulTensor(object): https://arxiv.org/abs/2108.05818 """ + # Global Stateful Tensor Manager GST_MGR = GeminiMemoryManager(TensorState) def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None: self._state = state self._payload = None - self._payload_size = 0 # byte size of current payload + self._payload_size = 0 # byte size of current payload StatefulTensor.GST_MGR.register_new_instance() @@ -47,7 +48,7 @@ def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorS def data_ptr(self): if self._payload is None: - return 0 # if a tensor has no storage, 0 should be returned + return 0 # if a tensor has no storage, 0 should be returned return self._payload.data_ptr() def set_null(self) -> None: @@ -80,7 +81,7 @@ def move_to(self, device: Union[torch.device, int]): assert self.state is not TensorState.FREE, "Can't move free stateful tensor" if not isinstance(device, torch.device): - to_device = torch.device('cuda', device) + to_device = torch.device("cuda", device) else: to_device = device @@ -97,7 +98,6 @@ def payload_copy(self, tensor) -> None: self._payload.view(-1).copy_(tensor.view(-1)) def payload_reset(self, tensor) -> None: - assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead" if self.payload is not None: @@ -168,8 +168,7 @@ def __release(self): self._payload_size = 0 def __trans_state_update(self, from_state: TensorState, to_state: TensorState): - """Update global manager when changing the state of a tensor - """ + """Update global manager when changing the state of a tensor""" manager = StatefulTensor.GST_MGR size = self.payload_size device_type = self.device.type @@ -189,8 +188,7 @@ def __trans_state_update(self, from_state: TensorState, to_state: TensorState): manager.total_mem[device_type] -= size def __trans_device_update(self, from_type: str, to_type: str): - """Update global manager when changing the device of a tensor - """ + """Update global manager when changing the device of a tensor""" manager = StatefulTensor.GST_MGR size = self.payload_size state = self.state diff --git a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py index 4f9ea7c6d520..19f77d4305af 100644 --- a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py @@ -3,14 +3,11 @@ from time import time from typing import List -import torch - -from colossalai.logging import get_dist_logger from colossalai.utils.cuda import get_current_device from .stateful_tensor import StatefulTensor, TensorState from .tensor_placement_policy import TensorPlacementPolicy -from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from .tensor_utils import colo_model_data_tensor_move_inline class StatefulTensorMgr(object): @@ -44,8 +41,7 @@ def start_iter(self): pass def finish_iter(self): - """This function must be called when each iteration finishes - """ + """This function must be called when each iteration finishes""" self._warmup = False self._compute_idx = -1 self._cpu_gpu_move_volume = 0 @@ -53,19 +49,21 @@ def finish_iter(self): self._evict_time = 0 def adjust_layout(self) -> None: - """ Adjust the layout of stateful tensor according to the information provided + """Adjust the layout of stateful tensor according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. """ # find stateful tensor in state COMPUTE - cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE] + cuda_demand = StatefulTensor.GST_MGR.state_mem["cpu"][TensorState.COMPUTE] start = time() move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup) self._layout_time += time() - start - vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list, - cuda_demand=cuda_demand, - warmup=self._warmup, - compute_list=self._compute_list, - compute_idx=self._compute_idx) + vol, evict_time = self._tensor_placement_policy.evict_tensors( + hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx, + ) self._cpu_gpu_move_volume += vol self._evict_time += evict_time # move COMPUTE tensors to CUDA @@ -92,10 +90,10 @@ def _get_layout_info(self, compute_idx: int, warmup: bool): if tensor.state == TensorState.FREE: continue - if tensor.device.type == 'cuda': + if tensor.device.type == "cuda": if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]: hold_cuda_tensor_list.append(tensor) - elif tensor.device.type == 'cpu': + elif tensor.device.type == "cpu": if tensor.state == TensorState.COMPUTE: move_to_cuda_tensor_list.append(tensor) else: diff --git a/colossalai/legacy/zero/gemini/tensor_placement_policy.py b/colossalai/legacy/zero/gemini/tensor_placement_policy.py index 275933ec2cfb..3aca80cfe56a 100644 --- a/colossalai/legacy/zero/gemini/tensor_placement_policy.py +++ b/colossalai/legacy/zero/gemini/tensor_placement_policy.py @@ -10,11 +10,10 @@ from colossalai.zero.gemini.memory_tracer import MemStatsCollector from .stateful_tensor import StatefulTensor -from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from .tensor_utils import colo_model_data_tensor_move_inline class TensorPlacementPolicy(ABC): - def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None: self.device: Optional[torch.device] = device self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector @@ -25,9 +24,8 @@ def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) - class CPUTensorPlacementPolicy(TensorPlacementPolicy): - def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: - super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector) + super().__init__(torch.device("cpu"), mem_stats_collector=mem_stats_collector) def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: volume = 0 @@ -38,9 +36,8 @@ def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) - class CUDATensorPlacementPolicy(TensorPlacementPolicy): - def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: - assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' + assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available" super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int: @@ -48,7 +45,6 @@ def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) - class AutoTensorPlacementPolicy(TensorPlacementPolicy): - def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: super().__init__(None, mem_stats_collector=mem_stats_collector) # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase @@ -56,13 +52,15 @@ def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> N self._warmup_non_model_data_ratio: float = 0.8 self._steady_cuda_cap_ratio: float = 0.9 - def evict_tensors(self, - hold_cuda_tensor_list: List[StatefulTensor], - cuda_demand: int = 0, - warmup: bool = True, - compute_list: List[StatefulTensor] = [], - compute_idx: int = 0, - **kwargs) -> int: + def evict_tensors( + self, + hold_cuda_tensor_list: List[StatefulTensor], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: List[StatefulTensor] = [], + compute_idx: int = 0, + **kwargs, + ) -> int: """ Evict tensors from CUDA device. @@ -81,13 +79,13 @@ def evict_tensors(self, """ start = time() cuda_capacity = colo_device_memory_capacity(get_current_device()) - used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda'] + used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio else: # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. - max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") cuda_capacity *= self._steady_cuda_cap_ratio total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data @@ -99,15 +97,16 @@ def evict_tensors(self, to_free_cuda_model_data = cuda_demand - avail_cuda_model_data to_free_tensor_list = hold_cuda_tensor_list if not warmup: - to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx, - tuple(compute_list)) + to_free_tensor_list = self._sort_hold_cuda_tensors( + tuple(hold_cuda_tensor_list), compute_idx, tuple(compute_list) + ) # print(self._sort_hold_cuda_tensors.cache_info()) end = time() for t in to_free_tensor_list: if freed_cuda_model_data >= to_free_cuda_model_data: break freed_cuda_model_data += t.payload_size - colo_model_data_tensor_move_inline(t, torch.device('cpu')) + colo_model_data_tensor_move_inline(t, torch.device("cpu")) if freed_cuda_model_data < to_free_cuda_model_data: raise RuntimeError( f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" @@ -126,14 +125,13 @@ def _sort_hold_cuda_tensors(hold_cuda_tensors: tuple, compute_idx: int, compute_ class TensorPlacementPolicyFactory: - @staticmethod def create(policy_name: str) -> Type[TensorPlacementPolicy]: - if policy_name == 'cpu': + if policy_name == "cpu": return CPUTensorPlacementPolicy - elif policy_name == 'cuda': + elif policy_name == "cuda": return CUDATensorPlacementPolicy - elif policy_name == 'auto': + elif policy_name == "auto": return AutoTensorPlacementPolicy else: raise TypeError(f"Unknown tensor placement policy {policy_name}") diff --git a/colossalai/legacy/zero/gemini/tensor_utils.py b/colossalai/legacy/zero/gemini/tensor_utils.py index 843e330ee2c6..6e51dee6ef94 100644 --- a/colossalai/legacy/zero/gemini/tensor_utils.py +++ b/colossalai/legacy/zero/gemini/tensor_utils.py @@ -30,16 +30,17 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[ cuda_use, cpu_use = 0, 0 mem_use = t.storage().size() * t.element_size() - if t.device.type == 'cuda': + if t.device.type == "cuda": cuda_use += mem_use - elif t.device.type == 'cpu': + elif t.device.type == "cpu": cpu_use += mem_use return cuda_use, cpu_use -def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor, - torch.Tensor]) -> None: +def colo_model_data_tensor_move( + src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor, torch.Tensor] +) -> None: """ A colossal API for model data tensor move. The src and target tensors could be resident on both CPU and GPU. @@ -71,8 +72,9 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_ src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype) -def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device, - int]) -> None: +def colo_model_data_tensor_move_inline( + t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device, int] +) -> None: """ move a tensor to the target_device Args: @@ -80,14 +82,14 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t target_device: a target device, if type is int, it the index of cuda card. """ if not isinstance(target_device, torch.device): - target_device = torch.device(f'cuda:{target_device}') + target_device = torch.device(f"cuda:{target_device}") if isinstance(t, torch.Tensor): t.data = t.data.to(target_device) elif isinstance(t, StatefulTensor): t.move_to(target_device) else: - raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}') + raise TypeError(f"colo_model_data_tensor_move_inline dose not accept type {type(t)}") def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None: @@ -100,9 +102,9 @@ def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None: if isinstance(t, torch.Tensor): t.data = t.data.cpu() elif isinstance(t, StatefulTensor): - t.move_to(torch.device('cpu')) + t.move_to(torch.device("cpu")) else: - raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}') + raise TypeError(f"colo_model_data_move_to_cpu dose not accept type {type(t)}") def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor: diff --git a/colossalai/legacy/zero/init_ctx/__init__.py b/colossalai/legacy/zero/init_ctx/__init__.py index 0a6f81566a9d..28ce72a18b31 100644 --- a/colossalai/legacy/zero/init_ctx/__init__.py +++ b/colossalai/legacy/zero/init_ctx/__init__.py @@ -1,3 +1,3 @@ from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator -__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator'] +__all__ = ["ZeroInitContext", "no_shard_zero_context", "no_shard_zero_decrator"] diff --git a/colossalai/legacy/zero/init_ctx/init_context.py b/colossalai/legacy/zero/init_ctx/init_context.py index 4a7e46408583..6c5a8122ef80 100644 --- a/colossalai/legacy/zero/init_ctx/init_context.py +++ b/colossalai/legacy/zero/init_ctx/init_context.py @@ -39,7 +39,7 @@ def __post_init__(self): assert self.is_replicated, "Non-replicated parameters can't be sharded." if self.is_replicated and not self.shard_param: - assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda." + assert self.target_device.type == "cuda", "Replicated no-shard parameters should be located in cuda." class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): @@ -59,15 +59,16 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int). """ - def __init__(self, - target_device: torch.device, - shard_strategy: BaseShardStrategy, - seed: int = 2**10 - 1, - shard_param: bool = False, - default_dtype: Optional[torch.dtype] = None, - bf16: bool = False, - model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)): - + def __init__( + self, + target_device: torch.device, + shard_strategy: BaseShardStrategy, + seed: int = 2**10 - 1, + shard_param: bool = False, + default_dtype: Optional[torch.dtype] = None, + bf16: bool = False, + model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long), + ): super().__init__(default_dtype=default_dtype) self.shard_strategy = shard_strategy self.param_list = [] @@ -103,7 +104,7 @@ def calc_fanin_fanout(tensor: torch.Tensor): assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters" # get correct shape of input tensor - if not hasattr(tensor, 'colo_attr') or not tensor.colo_attr.param_is_sharded: + if not hasattr(tensor, "colo_attr") or not tensor.colo_attr.param_is_sharded: tensor_shape = tensor.shape else: tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape @@ -137,13 +138,16 @@ def _pre_context_exec(self): self.module_load_from_state_dict = nn.Module._load_from_state_dict shard_strategy = self.shard_strategy if self.config.shard_param else None - nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict, - shard_strategy=shard_strategy) + nn.Module._load_from_state_dict = functools.partialmethod( + ShardedModelV2._colo_load_from_state_dict, shard_strategy=shard_strategy + ) self.module_state_dict = nn.Module.state_dict - nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict, - shard_strategy=shard_strategy, - state_dict_func=self.module_state_dict, - process_group=self.dp_process_group) + nn.Module.state_dict = functools.partialmethod( + ShardedModelV2._colo_state_dict, + shard_strategy=shard_strategy, + state_dict_func=self.module_state_dict, + process_group=self.dp_process_group, + ) # reserve rng states self.cpu_rng_state = torch.get_rng_state() @@ -152,16 +156,15 @@ def _pre_context_exec(self): # set new seed for initialization, since we initialize sharded tensor separately # we don't want all processes have the same seed # otherwise all sharded tensors are same after init - offset = self.seed + 1 # we want to have more 1 in binary format seed + offset = self.seed + 1 # we want to have more 1 in binary format seed torch.manual_seed(self.seed + offset * dist.get_rank()) def _post_context_exec(self): - """The callback function when exiting context. - """ + """The callback function when exiting context.""" # broadcast replicated no-shard parameters src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] for param in self.param_list: - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated: dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group) param.colo_attr.set_data_none() @@ -193,7 +196,7 @@ def half_fn(t: torch.Tensor): for param in module.parameters(recurse=False): # avoid adapting a param to ShardedParam twice - if hasattr(param, 'colo_attr'): + if hasattr(param, "colo_attr"): continue self.param_numel[param] = param.numel() @@ -216,7 +219,7 @@ def half_fn(t: torch.Tensor): if self.shard_param: self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group) - param.data = param.colo_attr.data_payload # set param.data to payload + param.data = param.colo_attr.data_payload # set param.data to payload # mark whether the param is replicated param.colo_attr.is_replicated = self.is_replicated @@ -251,15 +254,13 @@ def hijack_context_config(self, **kwargs): def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager: - return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()), - is_replicated=is_replicated, - shard_param=False) + return ZeroContextMgr().hijack_context_config( + target_device=torch.device("cuda", torch.cuda.current_device()), is_replicated=is_replicated, shard_param=False + ) def no_shard_zero_decrator(is_replicated: bool = True): - def _wrapper(init_func): - def _no_shard(*args, **kwargs): with no_shard_zero_context(is_replicated): ret = init_func(*args, **kwargs) diff --git a/colossalai/legacy/zero/shard_utils/__init__.py b/colossalai/legacy/zero/shard_utils/__init__.py index 5e5d63a7e768..945c77a412c1 100644 --- a/colossalai/legacy/zero/shard_utils/__init__.py +++ b/colossalai/legacy/zero/shard_utils/__init__.py @@ -2,4 +2,4 @@ from .bucket_tensor_shard_strategy import BucketTensorShardStrategy from .tensor_shard_strategy import TensorShardStrategy -__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy'] +__all__ = ["BaseShardStrategy", "TensorShardStrategy", "BucketTensorShardStrategy"] diff --git a/colossalai/legacy/zero/shard_utils/base_shard_strategy.py b/colossalai/legacy/zero/shard_utils/base_shard_strategy.py index 9fb80f57ae77..13e6f0e48298 100644 --- a/colossalai/legacy/zero/shard_utils/base_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/base_shard_strategy.py @@ -7,10 +7,8 @@ class BaseShardStrategy(ABC): - def __init__(self) -> None: - """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs. - """ + """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.""" super().__init__() @abstractmethod diff --git a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py index 1f7baad57816..b9d3071a877e 100644 --- a/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -18,7 +18,6 @@ class BucketTensorShardStrategy(TensorShardStrategy): """ def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): - tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded] if len(tensor_list) == 0: return @@ -40,8 +39,8 @@ def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist. buffer_list = [buffer.to(target_device) for buffer in buffer_list] offset = 0 for i, t in enumerate(tensor_list): - gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list] - gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape) + gathered_payload = [buffer[offset : offset + tensor_numels[i]] for buffer in buffer_list] + gathered_payload = torch.cat(gathered_payload)[: t.origin_numel].view(t.origin_shape) t.payload_reset(gathered_payload) t.is_sharded = False offset += tensor_numels[i] diff --git a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py index cc43907f6655..ebaef774bd06 100644 --- a/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py @@ -24,7 +24,7 @@ def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist. self._gather_tensor(t, process_group) def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None): - """ Shard tensor among processes. + """Shard tensor among processes. Args: t (ShardedTensor): a tensor to be sharded. @@ -33,9 +33,11 @@ def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGr """ if t.is_sharded: return - if t.payload.device.type == 'cuda': - assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\ + if t.payload.device.type == "cuda": + assert t.payload.device == get_current_device(), ( + f"shard tensor on cuda device index {t.payload.device.index}," f" but current cuda device is {get_current_device()}" + ) sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) t.payload_reset(sharded_payload) t.is_sharded = True diff --git a/colossalai/legacy/zero/sharded_model/__init__.py b/colossalai/legacy/zero/sharded_model/__init__.py index 93120bdc34b4..ecead2f6a657 100644 --- a/colossalai/legacy/zero/sharded_model/__init__.py +++ b/colossalai/legacy/zero/sharded_model/__init__.py @@ -1,3 +1,3 @@ from .sharded_model_v2 import ShardedModelV2 -__all__ = ['ShardedModelV2'] +__all__ = ["ShardedModelV2"] diff --git a/colossalai/legacy/zero/sharded_model/_utils.py b/colossalai/legacy/zero/sharded_model/_utils.py index b8a618ef5a0d..100762318593 100644 --- a/colossalai/legacy/zero/sharded_model/_utils.py +++ b/colossalai/legacy/zero/sharded_model/_utils.py @@ -25,7 +25,7 @@ def free_storage(data: torch.Tensor) -> None: @torch.no_grad() def alloc_storage(data: torch.Tensor, size: torch.Size) -> None: """Allocate storage for a tensor.""" - if data.storage().size() == size.numel(): # no need to reallocate + if data.storage().size() == size.numel(): # no need to reallocate return assert data.storage().size() == 0 data.storage().resize_(size.numel()) diff --git a/colossalai/legacy/zero/sharded_model/reduce_scatter.py b/colossalai/legacy/zero/sharded_model/reduce_scatter.py index 4fb507382df9..0f11365515d2 100644 --- a/colossalai/legacy/zero/sharded_model/reduce_scatter.py +++ b/colossalai/legacy/zero/sharded_model/reduce_scatter.py @@ -20,7 +20,6 @@ class Bucket: - def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup): self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device) self.group = group @@ -35,18 +34,18 @@ def flush(self) -> None: return # reduce-scatter bucket if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives: - dist._reduce_scatter_base(self.output_shard[:self.offset], - self.buffer[:, :self.offset].contiguous(), - group=self.group) + dist._reduce_scatter_base( + self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group + ) else: - dist.reduce_scatter(self.output_shard[:self.offset], - list(self.buffer[:, :self.offset].unbind(0)), - group=self.group) + dist.reduce_scatter( + self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group + ) # execute post-reduction callbacks for callback_fn in self.callbacks: callback_fn() # reuse input bucket but allocate a fresh output shard - self.buffer[:, :self.offset].zero_() + self.buffer[:, : self.offset].zero_() self.offset = 0 self.callbacks.clear() self.output_shard = torch.zeros_like(self.buffer[0]) @@ -74,12 +73,12 @@ def append(self, tensor_list: List[Tensor], callback_fn: Callable): tensor_size = tensor_list[0].numel() stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size) offset = self.offset - self.buffer[:, offset:offset + tensor_size].copy_(stacked_input) + self.buffer[:, offset : offset + tensor_size].copy_(stacked_input) self.offset += tensor_size # callback will be given the reduced result if callback_fn is not None: - result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0]) + result_view = self.output_shard[offset : offset + tensor_size].view_as(tensor_list[0]) self.callbacks.append(functools.partial(callback_fn, result_view)) @@ -142,8 +141,9 @@ def reduce_scatter_async( """ world_size = group.size() - assert (len(input_list) == world_size - ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})" + assert ( + len(input_list) == world_size + ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})" first_input = input_list[0] first_input_size = first_input.numel() @@ -183,7 +183,7 @@ def free(self) -> None: @functools.lru_cache() def _get_shard_size(self, element_size: int, num_shards: int) -> int: - if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. + if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing. return 0 MB = 1024 * 1024 bucket_size = self.bucket_size_mb * MB / element_size diff --git a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py index 91c21ccf9516..85f2ac2159f4 100644 --- a/colossalai/legacy/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/legacy/zero/sharded_model/sharded_model_v2.py @@ -2,7 +2,6 @@ import functools import itertools from collections import OrderedDict -from copy import deepcopy from typing import Any, Iterator, Optional, Tuple import torch @@ -40,7 +39,7 @@ try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" class ShardedModelV2(nn.Module): @@ -78,20 +77,22 @@ class ShardedModelV2(nn.Module): bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False. """ - def __init__(self, - module: nn.Module, - shard_strategy: BaseShardStrategy, - process_group: Optional[ProcessGroup] = None, - reduce_scatter_process_group: Optional[ProcessGroup] = None, - reduce_scatter_bucket_size_mb: int = 25, - fp32_reduce_scatter: bool = False, - tensor_placement_policy: str = 'cuda', - gradient_predivide_factor: Optional[float] = 1.0, - reuse_fp16_shard: bool = False, - bf16: bool = False, - *args, - **kwargs): - assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' + def __init__( + self, + module: nn.Module, + shard_strategy: BaseShardStrategy, + process_group: Optional[ProcessGroup] = None, + reduce_scatter_process_group: Optional[ProcessGroup] = None, + reduce_scatter_bucket_size_mb: int = 25, + fp32_reduce_scatter: bool = False, + tensor_placement_policy: str = "cuda", + gradient_predivide_factor: Optional[float] = 1.0, + reuse_fp16_shard: bool = False, + bf16: bool = False, + *args, + **kwargs, + ): + assert not isinstance(module, ShardedModelV2), "Nested ShardedModelV2 is not supported." super().__init__() self.logger = get_dist_logger() self.bf16 = bf16 @@ -101,13 +102,13 @@ def __init__(self, sharded_cnt = 0 unshard_cnt = 0 for param in submodule.parameters(recurse=False): - assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.' + assert hasattr(param, "colo_attr"), "You must use ZeroInitContext to init your module first." if param.colo_attr.param_is_sharded: sharded_cnt += 1 else: unshard_cnt += 1 - assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param' - submodule.param_is_sharded = (sharded_cnt > 0) + assert (not sharded_cnt) or (not unshard_cnt), "nn.Module can not both have shard param and unshard param" + submodule.param_is_sharded = sharded_cnt > 0 self.sharded_params = [] self.unshard_params = [] @@ -124,7 +125,7 @@ def __init__(self, self.rank = dist.get_rank(self.process_group) self.shard_strategy = shard_strategy - self._use_memory_tracer = tensor_placement_policy == 'auto' + self._use_memory_tracer = tensor_placement_policy == "auto" if self._use_memory_tracer: self._memstats_collector = MemStatsCollector() self._start_collect_memstats = disposable(self._memstats_collector.start_collection) @@ -132,18 +133,19 @@ def __init__(self, else: self._memstats_collector = None self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create( - tensor_placement_policy)(mem_stats_collector=self._memstats_collector) + tensor_placement_policy + )(mem_stats_collector=self._memstats_collector) - if 'warmup_non_model_data_ratio' in kwargs: - if tensor_placement_policy != 'auto': - self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement') + if "warmup_non_model_data_ratio" in kwargs: + if tensor_placement_policy != "auto": + self.logger.warning("setting warmup_non_model_data_ratio is useless if not use auto placement") else: - ratio = kwargs['warmup_non_model_data_ratio'] + ratio = kwargs["warmup_non_model_data_ratio"] self._tensor_placement_policy._warmup_non_model_data_ratio = ratio - self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement') + self.logger.info(f"setting warmup_non_model_data_ratio as {ratio} for auto placement") self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy) - param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')] + param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, "colo_attr")] self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list) # Register hooks @@ -155,7 +157,7 @@ def __init__(self, self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) self.fp32_reduce_scatter = fp32_reduce_scatter - self._cpu_offload: bool = tensor_placement_policy != 'cuda' + self._cpu_offload: bool = tensor_placement_policy != "cuda" for param in module.parameters(): # Init `offload_grad` param.colo_attr.offload_grad = self._cpu_offload @@ -164,9 +166,11 @@ def __init__(self, # So we use 1.0 as the default gradient_predivide_factor # However, if you set gradient_predivide_factor to None, we will set # gradient_predivide_factor to a value >= 1.0 automatically - self.gradient_predivide_factor: float = gradient_predivide_factor if \ - gradient_predivide_factor is not None else \ - get_gradient_predivide_factor(self.world_size) + self.gradient_predivide_factor: float = ( + gradient_predivide_factor + if gradient_predivide_factor is not None + else get_gradient_predivide_factor(self.world_size) + ) self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() @@ -194,7 +198,7 @@ def cuda_margin_space(self): def cpu_offload(self): return self._cpu_offload - def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None: + def dump_memory_stats(self, filename: Optional[str] = "dump_mem_stats.log") -> None: """ dummy memory tracer collected information to a file. try: @@ -205,18 +209,18 @@ def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> N exit(0) """ if self._use_memory_tracer: - self.logger.error(f'dump memory tracer collected information to a {filename}', ranks=[0]) + self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0]) if gpc.get_global_rank() == 0: - with open(filename, 'w+') as f: - f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n') - f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n') - f.write('CUDA model data (GB)\n') - f.write('\n') - f.write('CUDA non model data (GB)\n') - f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda'))) - f.write('CPU non model data (GB)\n') - f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu'))) - f.write('\n') + with open(filename, "w+") as f: + f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n") + f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n") + f.write("CUDA model data (GB)\n") + f.write("\n") + f.write("CUDA non model data (GB)\n") + f.write(str(self._memstats_collector._memstats.non_model_data_list("cuda"))) + f.write("CPU non model data (GB)\n") + f.write(str(self._memstats_collector._memstats.non_model_data_list("cpu"))) + f.write("\n") def _pre_forward_operations(self, *args): # the operation will affect the memory tracer behavior in ZeroHook @@ -224,14 +228,14 @@ def _pre_forward_operations(self, *args): self._start_collect_memstats() for p in self.module.parameters(): - if hasattr(p, 'colo_attr'): + if hasattr(p, "colo_attr"): p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) self._stateful_tensor_mgr.start_iter() def _post_forward_operations(self): for p in self.module.parameters(): - if hasattr(p, 'colo_attr'): + if hasattr(p, "colo_attr"): p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: @@ -261,8 +265,9 @@ def _update_memstats(self): # the way to calculate margin space is based on the assumption that # model data is fixed in cuda during training. # cuda margin space can be used to store OS. - self._cuda_margin_space = colo_device_memory_capacity( - get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + self._cuda_margin_space = ( + colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda + ) @torch.no_grad() def _post_backward_operations(self) -> None: @@ -330,7 +335,7 @@ def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Opti """ if grad is None: return - assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients' + assert not grad.requires_grad, "ShardedModel only works with gradients that don't require gradients" if not self._require_backward_grad_sync: return # used to cheat Pytorch, since we can't return None @@ -354,16 +359,19 @@ def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None: grad.data.div_(self.gradient_predivide_factor) if self.world_size > 1: grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size()) - self.reducer.reduce_scatter_async(grad_chunks, - group=self.reduce_scatter_process_group, - callback_fn=functools.partial(self._reduce_scatter_callback, param)) + self.reducer.reduce_scatter_async( + grad_chunks, + group=self.reduce_scatter_process_group, + callback_fn=functools.partial(self._reduce_scatter_callback, param), + ) else: self._reduce_scatter_callback(param, grad) torch.cuda.current_stream().wait_stream(self.comm_stream) def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: - assert isinstance(reduced_grad, - torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}" + assert isinstance( + reduced_grad, torch.Tensor + ), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}" reduced_grad.data = reduced_grad.data.contiguous().view(-1) if self.gradient_postdivide_factor > 1: # Average grad by world_size for consistency with PyTorch DDP. @@ -372,7 +380,6 @@ def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) # FIXME(ver217): refactor the below line when impl eviction policy def _save_grad(self, param: Parameter, grad: torch.Tensor): - # record whether we have overflow self.overflow_counter += torch.isinf(grad).any().item() self.overflow_counter += torch.isnan(grad).any().item() @@ -384,8 +391,9 @@ def _save_grad(self, param: Parameter, grad: torch.Tensor): if self.reuse_fp16_shard: # make parameters point to gradient - assert param.colo_attr.saved_grad.is_null( - ), 'Gradient accumulation is not supported when reuse_fp16_shard=True' + assert ( + param.colo_attr.saved_grad.is_null() + ), "Gradient accumulation is not supported when reuse_fp16_shard=True" param.colo_attr.grad_payload_reset(grad.data) # release the memory of param @@ -396,7 +404,6 @@ def _save_grad(self, param: Parameter, grad: torch.Tensor): if param.colo_attr.is_replicated: param.colo_attr.sharded_data_tensor.is_sharded = True else: - fp32_grad = cast_tensor_to_fp32(grad) if param.colo_attr.saved_grad.is_null(): @@ -410,39 +417,44 @@ def _save_grad(self, param: Parameter, grad: torch.Tensor): def parameters(self, recurse: bool = True) -> Iterator[Parameter]: return self.module.parameters(recurse=recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: return self.module.named_parameters(prefix, recurse) - def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]': - return self._colo_state_dict(destination, - prefix, - keep_vars, - shard_strategy=self.shard_strategy, - state_dict_func=nn.Module.state_dict, - module_to_load=self.module, - sharded_params=self.sharded_params, - process_group=self.process_group) - - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None: + def state_dict(self, destination=None, prefix="", keep_vars=False) -> "OrderedDict[str, torch.Tensor]": + return self._colo_state_dict( + destination, + prefix, + keep_vars, + shard_strategy=self.shard_strategy, + state_dict_func=nn.Module.state_dict, + module_to_load=self.module, + sharded_params=self.sharded_params, + process_group=self.process_group, + ) + + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True) -> None: for name, p in self.named_parameters(): if name in state_dict: - p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, - device=p.colo_attr.data_payload.device)) + p.colo_attr.data_payload_reset( + state_dict[name].to(dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device) + ) # Force re-shard p.colo_attr.sharded_data_tensor.is_sharded = False self.shard_strategy.shard([p.colo_attr.sharded_data_tensor]) elif strict: - raise RuntimeError(f'Missing key in state_dict: {name}') - - def _colo_state_dict(self, - destination=None, - prefix='', - keep_vars=False, - shard_strategy: Optional[BaseShardStrategy] = None, - state_dict_func=None, - module_to_load=None, - sharded_params=[], - process_group=None) -> 'OrderedDict[str, torch.Tensor]': + raise RuntimeError(f"Missing key in state_dict: {name}") + + def _colo_state_dict( + self, + destination=None, + prefix="", + keep_vars=False, + shard_strategy: Optional[BaseShardStrategy] = None, + state_dict_func=None, + module_to_load=None, + sharded_params=[], + process_group=None, + ) -> "OrderedDict[str, torch.Tensor]": if len(sharded_params) == 0: for param in self.parameters(): if param.colo_attr.param_is_sharded: @@ -460,15 +472,9 @@ def _colo_state_dict(self, p.colo_attr.set_data_none() return gathered_state_dict - def _colo_load_from_state_dict(self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - shard_strategy=None): + def _colo_load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs, shard_strategy=None + ): r"""Copies parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this @@ -512,10 +518,12 @@ def _colo_load_from_state_dict(self, key = prefix + name if key in state_dict: input_param = state_dict[key] - if hasattr(param, 'colo_attr'): + if hasattr(param, "colo_attr"): param.colo_attr.data_payload_reset( - input_param.to(dtype=param.colo_attr.data_payload.dtype, - device=param.colo_attr.data_payload.device)) + input_param.to( + dtype=param.colo_attr.data_payload.dtype, device=param.colo_attr.data_payload.device + ) + ) if shard_strategy is not None: # Force re-shard param.colo_attr.sharded_data_tensor.is_sharded = False @@ -531,19 +539,21 @@ def _colo_load_from_state_dict(self, if not is_param_lazy and input_param.shape != param.shape: # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format( - key, input_param.shape, param.shape)) + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(key, input_param.shape, param.shape) + ) continue try: with torch.no_grad(): param.copy_(input_param) except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(key, param.size(), input_param.size(), - ex.args)) + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) elif strict: missing_keys.append(key) @@ -559,8 +569,8 @@ def _colo_load_from_state_dict(self, if strict: for key in state_dict.keys(): if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):] - input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + input_name = key[len(prefix) :] + input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child if input_name not in self._modules and input_name not in local_state: unexpected_keys.append(key) diff --git a/colossalai/legacy/zero/sharded_model/utils.py b/colossalai/legacy/zero/sharded_model/utils.py index 7a411669900b..cb085f19e6b2 100644 --- a/colossalai/legacy/zero/sharded_model/utils.py +++ b/colossalai/legacy/zero/sharded_model/utils.py @@ -11,7 +11,7 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu Note the other_model has to be the same as self. """ for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()): - assert hasattr(zero_param, 'colo_attr') + assert hasattr(zero_param, "colo_attr") shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded if shard_flag: sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor]) diff --git a/colossalai/legacy/zero/sharded_model/zero_hook.py b/colossalai/legacy/zero/sharded_model/zero_hook.py index 3fc373e5ca44..892e9f31ded4 100644 --- a/colossalai/legacy/zero/sharded_model/zero_hook.py +++ b/colossalai/legacy/zero/sharded_model/zero_hook.py @@ -20,11 +20,13 @@ class ZeroHook(BaseOpHook): Warning: this class has been deprecated after version 0.1.12 """ - def __init__(self, - shard_strategy: BaseShardStrategy, - memstarts_collector: Optional[MemStatsCollector] = None, - stateful_tensor_mgr: Optional[StatefulTensorMgr] = None, - process_group: Optional[dist.ProcessGroup] = None): + def __init__( + self, + shard_strategy: BaseShardStrategy, + memstarts_collector: Optional[MemStatsCollector] = None, + stateful_tensor_mgr: Optional[StatefulTensorMgr] = None, + process_group: Optional[dist.ProcessGroup] = None, + ): super().__init__() self.logger = get_dist_logger("ZeROHook") self.shard_strategy = shard_strategy @@ -41,7 +43,7 @@ def gather_parameters(self, module: torch.nn.Module): if module.param_is_sharded: tensor_list = [] for param in module.parameters(recurse=False): - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") tensor_list.append(param.colo_attr.sharded_data_tensor) self.shard_strategy.gather(tensor_list, self.process_group) @@ -50,7 +52,7 @@ def shard_parameters(self, module: torch.nn.Module): if module.param_is_sharded: tensor_list = [] for param in module.parameters(recurse=False): - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") tensor_list.append(param.colo_attr.sharded_data_tensor) self.shard_strategy.shard(tensor_list, self.process_group) @@ -74,10 +76,9 @@ def pre_fwd_exec(self, module: torch.nn.Module, *args): self.gather_parameters(module) for param in module.parameters(recurse=False): param.data = param.colo_attr.data_payload - assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA" + assert param.data.device.type == "cuda", f"PRE FWD param.data must be on CUDA" def post_fwd_exec(self, module: torch.nn.Module, *args): - # change tensor state to HOLD_AFTER_FWD for param in module.parameters(recurse=False): param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD) @@ -93,10 +94,9 @@ def pre_bwd_exec(self, module: torch.nn.Module, input, output): self.gather_parameters(module) for param in module.parameters(recurse=False): param.data = param.colo_attr.data_payload - assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA" + assert param.data.device.type == "cuda", f"PRE BWD param.data must be on CUDA" def post_bwd_exec(self, module: torch.nn.Module, input): - # change tensor state to HOLD_AFTER_BWD for param in module.parameters(recurse=False): param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD) @@ -114,5 +114,6 @@ def post_iter(self): if self._stateful_tensor_mgr: self.logger.debug( f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}", - ranks=[0]) + ranks=[0], + ) self._stateful_tensor_mgr.finish_iter() diff --git a/colossalai/legacy/zero/sharded_optim/__init__.py b/colossalai/legacy/zero/sharded_optim/__init__.py index b71a70aeffa4..700fb0eb91d3 100644 --- a/colossalai/legacy/zero/sharded_optim/__init__.py +++ b/colossalai/legacy/zero/sharded_optim/__init__.py @@ -1,3 +1,3 @@ from .sharded_optim_v2 import ShardedOptimizerV2 -__all__ = ['ShardedOptimizerV2'] +__all__ = ["ShardedOptimizerV2"] diff --git a/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py b/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py index e21f1cea04df..e73679163fab 100644 --- a/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/legacy/zero/sharded_optim/sharded_optim_v2.py @@ -1,6 +1,5 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch from enum import Enum -from os import stat from typing import Dict, Optional, Tuple import torch @@ -74,22 +73,24 @@ class ShardedOptimizerV2(OptimizerWrapper): https://arxiv.org/abs/2108.05818 """ - def __init__(self, - sharded_model: ShardedModelV2, - optimizer: Optimizer, - gpu_margin_mem_ratio: float = 0.0, - initial_scale: float = 2**32, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - dp_process_group: Optional[ProcessGroup] = None, - mp_process_group: Optional[ProcessGroup] = None, - verbose: bool = False) -> None: - assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel' - assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.' + def __init__( + self, + sharded_model: ShardedModelV2, + optimizer: Optimizer, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + dp_process_group: Optional[ProcessGroup] = None, + mp_process_group: Optional[ProcessGroup] = None, + verbose: bool = False, + ) -> None: + assert isinstance(sharded_model, ShardedModelV2), "model must be wrapped with ShardedModel" + assert not isinstance(optimizer, ShardedOptimizerV2), "Nested ShardedOptimizerV2 is not supported." super().__init__(optimizer) self.shard_strategy = sharded_model.shard_strategy @@ -97,39 +98,49 @@ def __init__(self, self.bf16 = sharded_model.bf16 self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) - assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0" # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, # and it must set `num_fp32_shards_per_param` correctly - self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr( - optimizer, 'num_fp32_shards_per_param', 0) >= 2 - self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu') + self._should_move_fp32_shards_h2d: bool = ( + sharded_model.cpu_offload + and self.gpu_margin_mem_ratio > 0.0 + and getattr(optimizer, "num_fp32_shards_per_param", 0) >= 2 + ) + self.device = sharded_model._tensor_placement_policy.device or torch.device("cpu") self.optim_state: OptimState = OptimState.UNSCALED self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA) self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL) # Grad scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.grad_scaler = DynamicGradScaler( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device()) self._logger = get_dist_logger("ShardedOptimizerV2") self._verbose = verbose - self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward + self._grad_prepared: bool = ( + False # this should be set to true when _prepare_grads() and reset to false when backward + ) # Store fp32 param shards self._register_master_weight() - if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy, - AutoTensorPlacementPolicy): - self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', - ranks=[0]) + if self.gpu_margin_mem_ratio != 0.0 and not isinstance( + sharded_model._tensor_placement_policy, AutoTensorPlacementPolicy + ): + self._logger.warning( + f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"', ranks=[0] + ) if self._verbose: self._logger.debug( - f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0]) + f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0] + ) self._use_memory_tracer = self.model.use_memory_tracer @@ -138,7 +149,7 @@ def loss_scale(self): return self.grad_scaler.scale.item() def get_memory_usage(self) -> Tuple[int, int]: - """ Get the memory usage of the optimizer. Including master_params (param fp32), + """Get the memory usage of the optimizer. Including master_params (param fp32), momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``) Returns: @@ -157,7 +168,7 @@ def update_mem_use(t): for _, p_fp32 in self.master_params.items(): update_mem_use(p_fp32) for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: state = self.optim.state[p] for k, v in state.items(): update_mem_use(v) @@ -191,7 +202,6 @@ def clip_grad_norm(self, model: nn.Module, max_norm: float): return super().clip_grad_norm(model, max_norm) def step(self, *args, **kwargs): - self._prepare_grads() # unscale grads if scaled if not self.bf16 and self.optim_state == OptimState.SCALED: @@ -203,7 +213,7 @@ def step(self, *args, **kwargs): self.grad_scaler.update(found_inf) if found_inf: - self._logger.warning('found inf during ShardedOptimV2 step') + self._logger.warning("found inf during ShardedOptimV2 step") self._zero_grad(recover_data=True) return @@ -213,14 +223,16 @@ def step(self, *args, **kwargs): gpu_mem, cpu_mem = self.get_memory_usage() self._logger.debug( f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!", - ranks=[0]) + ranks=[0], + ) ret = self.optim.step(*args, **kwargs) if self._verbose: gpu_mem, cpu_mem = self.get_memory_usage() self._logger.debug( f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!", - ranks=[0]) + ranks=[0], + ) self._copy_master_model_to_model_fp16() return ret @@ -240,7 +252,7 @@ def _check_overflow(self): def _unscale_grads(self): assert self.optim_state == OptimState.SCALED for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is not None: p.grad.data.div_(self.loss_scale) self.optim_state = OptimState.UNSCALED @@ -260,16 +272,16 @@ def _zero_grad(self, recover_data: bool = False): # Which leads to wrong accumulation self.optim.zero_grad(set_to_none=True) for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: # p.colo_attr.sharded_data_tensor stores grad now # we have to recover fp16 param - reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0) + reuse_fp16_shard = p.colo_attr.sharded_data_tensor.payload_size == 0 if recover_data and reuse_fp16_shard: self._copy_master_param_to_param_fp16(p) else: # release saved gradient p.colo_attr.saved_grad.set_null() - self.model.overflow_counter = 0 # set overflow counter to zero + self.model.overflow_counter = 0 # set overflow counter to zero def sync_grad(self): pass @@ -277,8 +289,8 @@ def sync_grad(self): def _register_master_weight(self): self.master_params: Dict[Parameter, StatefulTensor] = {} for group in self.optim.param_groups: - for p in group['params']: - assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam' + for p in group["params"]: + assert hasattr(p, "colo_attr"), "The parameter must be wrapped with ShardedParam" shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated if shard_flag: # we always shard replicated parameters @@ -296,7 +308,7 @@ def _maybe_move_fp32_shards(self): fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param fp32_shards_used_cuda_margin_mem = 0 for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: if p.colo_attr.saved_grad.is_null(): continue shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size() @@ -314,7 +326,7 @@ def _prepare_grads(self): if self._grad_prepared: return for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: if p.colo_attr.saved_grad.is_null(): continue p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE) @@ -335,7 +347,7 @@ def _point_param_fp16_to_master_param(self): # assign master param pointers to p.data. # We will not trigger data copy here. for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: self.master_params[p].trans_state(TensorState.COMPUTE) p.data = self.master_params[p].payload # Now p.data is sharded @@ -346,7 +358,7 @@ def _copy_master_model_to_model_fp16(self): # TODO() improve efficiency by gathering tensors into a chunk and transferring # a chunk. for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: self._copy_master_param_to_param_fp16(p) def _copy_master_param_to_param_fp16(self, p): @@ -364,7 +376,8 @@ def _copy_master_param_to_param_fp16(self, p): # in order to use copy, otherwise, the sizes of tensor is not compatible if p.colo_attr.data_payload.numel() != p.data.numel(): p.colo_attr.data_payload_reset( - torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)) + torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device) + ) # TODO() optimize this line CPU (fp32) -> GPU (fp16) half_dtype = torch.bfloat16 if self.bf16 else torch.float16 @@ -373,7 +386,7 @@ def _copy_master_param_to_param_fp16(self, p): if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: # We gather full fp16 param here - p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True + p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) self.master_params[p].trans_state(TensorState.HOLD) @@ -381,18 +394,18 @@ def _copy_master_param_to_param_fp16(self, p): def state_dict(self): optim_state_dict = super().state_dict() scaler_state_dict = self.grad_scaler.state_dict() - optim_state_dict['scaler'] = scaler_state_dict + optim_state_dict["scaler"] = scaler_state_dict return optim_state_dict def load_state_dict(self, *args, **kwargs): - if 'scaler' not in args[0]: - self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0]) + if "scaler" not in args[0]: + self._logger.warning("Missing scaler when loading optimizer state dict", ranks=[0]) else: - scaler_state_dict = args[0].pop('scaler') + scaler_state_dict = args[0].pop("scaler") self.grad_scaler.load_state_dict(scaler_state_dict) super().load_state_dict(*args, **kwargs) for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: state = self.optim.state[p] for k, v in state.items(): if isinstance(v, Tensor): diff --git a/colossalai/legacy/zero/sharded_param/__init__.py b/colossalai/legacy/zero/sharded_param/__init__.py index 47e2ce2fa0e0..c7afb95391a4 100644 --- a/colossalai/legacy/zero/sharded_param/__init__.py +++ b/colossalai/legacy/zero/sharded_param/__init__.py @@ -1,4 +1,4 @@ from .sharded_param import ShardedParamV2 from .sharded_tensor import ShardedTensor -__all__ = ['ShardedTensor', 'ShardedParamV2'] +__all__ = ["ShardedTensor", "ShardedParamV2"] diff --git a/colossalai/legacy/zero/sharded_param/sharded_param.py b/colossalai/legacy/zero/sharded_param/sharded_param.py index 454a722cf7e7..22b09d5ff4bb 100644 --- a/colossalai/legacy/zero/sharded_param/sharded_param.py +++ b/colossalai/legacy/zero/sharded_param/sharded_param.py @@ -19,7 +19,6 @@ def get_empty_tensor(device: torch.device, dtype: torch.dtype): class ShardedParamV2(object): - def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None: self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data) self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE) @@ -36,8 +35,7 @@ def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> No self.set_data_none() def get_payload_tensors(self) -> List[StatefulTensor]: - """returns stateful tensors kept by this class. - """ + """returns stateful tensors kept by this class.""" return [self._sharded_data_tensor] def set_data_none(self): diff --git a/colossalai/legacy/zero/sharded_param/sharded_tensor.py b/colossalai/legacy/zero/sharded_param/sharded_tensor.py index 43c7576b93b5..262682d44645 100644 --- a/colossalai/legacy/zero/sharded_param/sharded_tensor.py +++ b/colossalai/legacy/zero/sharded_param/sharded_tensor.py @@ -4,7 +4,6 @@ class ShardedTensor(StatefulTensor): - def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None: r""" A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance. diff --git a/colossalai/logging/__init__.py b/colossalai/logging/__init__.py index 97fe4f89ded3..521eafa74c30 100644 --- a/colossalai/logging/__init__.py +++ b/colossalai/logging/__init__.py @@ -3,23 +3,23 @@ from .logger import DistributedLogger -__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers'] +__all__ = ["get_dist_logger", "DistributedLogger", "disable_existing_loggers"] -def get_dist_logger(name: str = 'colossalai') -> DistributedLogger: +def get_dist_logger(name: str = "colossalai") -> DistributedLogger: """Get logger instance based on name. The DistributedLogger will create singleton instances, which means that only one logger instance is created per name. Args: name (str): name of the logger, name must be unique - + Returns: :class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance. """ return DistributedLogger.get_instance(name=name) -def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None: +def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ["colossalai"]) -> None: """Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai". Args: diff --git a/colossalai/logging/logger.py b/colossalai/logging/logger.py index fd05ddf1d50f..eb5f28e2a3cf 100644 --- a/colossalai/logging/logger.py +++ b/colossalai/logging/logger.py @@ -42,12 +42,14 @@ def get_instance(name: str): def __init__(self, name): if name in DistributedLogger.__instances: raise Exception( - 'Logger with the same name has been created, you should use colossalai.logging.get_dist_logger') + "Logger with the same name has been created, you should use colossalai.logging.get_dist_logger" + ) else: handler = None - formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + formatter = logging.Formatter("colossalai - %(name)s - %(levelname)s: %(message)s") try: from rich.logging import RichHandler + handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True) handler.setFormatter(formatter) except ImportError: @@ -79,7 +81,7 @@ def __get_call_info(): @staticmethod def _check_valid_logging_level(level: str): - assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level' + assert level in ["INFO", "DEBUG", "WARNING", "ERROR"], "found invalid logging level" def set_level(self, level: str) -> None: """Set the logging level @@ -90,7 +92,7 @@ def set_level(self, level: str) -> None: self._check_valid_logging_level(level) self._logger.setLevel(getattr(logging, level)) - def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None: + def log_to_file(self, path: Union[str, Path], mode: str = "a", level: str = "INFO", suffix: str = None) -> None: """Save the logs to file Args: @@ -99,8 +101,7 @@ def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INF level (str): Can only be INFO, DEBUG, WARNING and ERROR. suffix (str): The suffix string of log's name. """ - assert isinstance(path, (str, Path)), \ - f'expected argument path to be type str or Path, but got {type(path)}' + assert isinstance(path, (str, Path)), f"expected argument path to be type str or Path, but got {type(path)}" self._check_valid_logging_level(level) if isinstance(path, str): @@ -110,15 +111,15 @@ def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INF path.mkdir(parents=True, exist_ok=True) if suffix is not None: - log_file_name = f'rank_{self.rank}_{suffix}.log' + log_file_name = f"rank_{self.rank}_{suffix}.log" else: - log_file_name = f'rank_{self.rank}.log' + log_file_name = f"rank_{self.rank}.log" path = path.joinpath(log_file_name) # add file handler file_handler = logging.FileHandler(path, mode) file_handler.setLevel(getattr(logging, level)) - formatter = logging.Formatter('colossalai - %(name)s - %(levelname)s: %(message)s') + formatter = logging.Formatter("colossalai - %(name)s - %(levelname)s: %(message)s") file_handler.setFormatter(formatter) self._logger.addHandler(file_handler) @@ -137,8 +138,8 @@ def info(self, message: str, ranks: List[int] = None) -> None: ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('info', message_prefix, ranks) - self._log('info', message, ranks) + self._log("info", message_prefix, ranks) + self._log("info", message, ranks) def warning(self, message: str, ranks: List[int] = None) -> None: """Log a warning message. @@ -148,8 +149,8 @@ def warning(self, message: str, ranks: List[int] = None) -> None: ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('warning', message_prefix, ranks) - self._log('warning', message, ranks) + self._log("warning", message_prefix, ranks) + self._log("warning", message, ranks) def debug(self, message: str, ranks: List[int] = None) -> None: """Log a debug message. @@ -159,8 +160,8 @@ def debug(self, message: str, ranks: List[int] = None) -> None: ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('debug', message_prefix, ranks) - self._log('debug', message, ranks) + self._log("debug", message_prefix, ranks) + self._log("debug", message, ranks) def error(self, message: str, ranks: List[int] = None) -> None: """Log an error message. @@ -170,5 +171,5 @@ def error(self, message: str, ranks: List[int] = None) -> None: ranks (List[int]): List of parallel ranks. """ message_prefix = "{}:{} {}".format(*self.__get_call_info()) - self._log('error', message_prefix, ranks) - self._log('error', message, ranks) + self._log("error", message_prefix, ranks) + self._log("error", message, ranks) diff --git a/colossalai/nn/init.py b/colossalai/nn/init.py index 559b7038fc35..2637aa8eaaeb 100644 --- a/colossalai/nn/init.py +++ b/colossalai/nn/init.py @@ -1,8 +1,8 @@ import math import warnings -from torch import Tensor import torch.nn as nn +from torch import Tensor def zeros_(): @@ -23,7 +23,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def uniform_(a: float = 0., b: float = 1.): +def uniform_(a: float = 0.0, b: float = 1.0): r"""Return the initializer filling the input Tensor with values drawn from the uniform distribution :math:`\mathcal{U}(a, b)`. @@ -38,7 +38,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def normal_(mean: float = 0., std: float = 1.): +def normal_(mean: float = 0.0, std: float = 1.0): r"""Return the initializer filling the input Tensor with values drawn from the normal distribution .. math:: @@ -47,7 +47,7 @@ def normal_(mean: float = 0., std: float = 1.): Args: mean (float): the mean of the normal distribution. Defaults 0.0. std (float): the standard deviation of the normal distribution. Defaults 1.0. - """ + """ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return nn.init.normal_(tensor, mean, std) @@ -55,7 +55,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.): +def trunc_normal_(mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0): r"""Return the initializer filling the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` @@ -76,7 +76,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'): +def kaiming_uniform_(a=0, mode="fan_in", nonlinearity="leaky_relu"): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification` - He, K. et al. (2015), using a @@ -104,23 +104,23 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): warnings.warn("Initializing zero-element tensors is a no-op") return tensor - if mode == 'fan_in': - assert fan_in is not None, 'Fan_in is not provided.' + if mode == "fan_in": + assert fan_in is not None, "Fan_in is not provided." fan = fan_in - elif mode == 'fan_out': - assert fan_out is not None, 'Fan_out is not provided.' + elif mode == "fan_out": + assert fan_out is not None, "Fan_out is not provided." fan = fan_out else: - raise ValueError(f'Invalid initialization mode \'{mode}\'') + raise ValueError(f"Invalid initialization mode '{mode}'") std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan) - bound = math.sqrt(3.) * std + bound = math.sqrt(3.0) * std return nn.init.uniform_(tensor, -bound, bound) return initializer -def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'): +def kaiming_normal_(a=0, mode="fan_in", nonlinearity="leaky_relu"): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification` - He, K. et al. (2015), using a @@ -148,14 +148,14 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): warnings.warn("Initializing zero-element tensors is a no-op") return tensor - if mode == 'fan_in': - assert fan_in is not None, 'Fan_in is not provided.' + if mode == "fan_in": + assert fan_in is not None, "Fan_in is not provided." fan = fan_in - elif mode == 'fan_out': - assert fan_out is not None, 'Fan_out is not provided.' + elif mode == "fan_out": + assert fan_out is not None, "Fan_out is not provided." fan = fan_out else: - raise ValueError(f'Invalid initialization mode \'{mode}\'') + raise ValueError(f"Invalid initialization mode '{mode}'") std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan) return nn.init.normal_(tensor, 0, std) @@ -163,7 +163,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.): +def xavier_uniform_(a: float = math.sqrt(3.0), scale: float = 2.0, gain: float = 1.0): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform @@ -184,7 +184,7 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1 # adapted from torch.nn.init def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." fan = fan_in if fan_out is not None: @@ -197,7 +197,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): return initializer -def xavier_normal_(scale: float = 2., gain: float = 1.): +def xavier_normal_(scale: float = 2.0, gain: float = 1.0): r"""Return the initializer filling the input `Tensor` with values according to the method described in `Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal @@ -216,7 +216,7 @@ def xavier_normal_(scale: float = 2., gain: float = 1.): # adapted from torch.nn.init def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." fan = fan_in if fan_out is not None: @@ -224,7 +224,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): std = gain * math.sqrt(scale / float(fan)) - return nn.init.normal_(tensor, 0., std) + return nn.init.normal_(tensor, 0.0, std) return initializer @@ -232,7 +232,7 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def lecun_uniform_(): # adapted from jax.nn.initializers def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." var = 1.0 / fan_in bound = math.sqrt(3 * var) @@ -244,9 +244,9 @@ def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): def lecun_normal_(): # adapted from jax.nn.initializers def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None): - assert fan_in is not None, 'Fan_in is not provided.' + assert fan_in is not None, "Fan_in is not provided." std = math.sqrt(1.0 / fan_in) - return nn.init.trunc_normal_(tensor, std=std / .87962566103423978) + return nn.init.trunc_normal_(tensor, std=std / 0.87962566103423978) return initializer diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index 05333fe965f1..6a5ccff510be 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -5,6 +5,17 @@ from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts __all__ = [ - 'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator', - 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model' + "Experts", + "FFNExperts", + "TPExperts", + "Top1Router", + "Top2Router", + "MoeLayer", + "NormalNoiseGenerator", + "UniformNoiseGenerator", + "build_ffn_experts", + "MoeModule", + "MoeRouter", + "save_moe_model", + "load_moe_model", ] diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py index 37f31c16709b..2f0b7e43673a 100644 --- a/colossalai/nn/layer/moe/_operation.py +++ b/colossalai/nn/layer/moe/_operation.py @@ -18,18 +18,18 @@ def build_moe_if_not_prebuilt(): global moe if moe is None: from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() class AllGather(torch.autograd.Function): - @staticmethod def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - global moe if moe is None: from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() if ctx is not None: @@ -51,7 +51,6 @@ def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]: class ReduceScatter(torch.autograd.Function): - @staticmethod def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: if ctx is not None: @@ -98,7 +97,6 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: class MoeDispatch(torch.autograd.Function): - @staticmethod def forward(ctx, tokens, mask, dest_idx, ec): s = tokens.size(0) @@ -124,7 +122,6 @@ def backward(ctx, output_grad): class MoeCombine(torch.autograd.Function): - @staticmethod def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): assert logits.dtype == torch.float32 @@ -137,7 +134,7 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): # load moe kernel during runtime if not pre-built build_moe_if_not_prebuilt() - fp16_flag = (expert_tokens.dtype == torch.float16) + fp16_flag = expert_tokens.dtype == torch.float16 cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) output = ctokens.to(torch.float16) if fp16_flag else ctokens @@ -155,8 +152,7 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): def backward(ctx, tokens_grad): expert_tokens, logits, mask, dest_idx = ctx.saved_tensors - cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \ - else tokens_grad + cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx) d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py index efda1f22252d..adad19d581ef 100644 --- a/colossalai/nn/layer/moe/checkpoint.py +++ b/colossalai/nn/layer/moe/checkpoint.py @@ -16,7 +16,7 @@ def load_moe_model(model: nn.Module, load_path: str): state_dict = torch.load(load_path) for prefix, module in model.named_modules(): - if prefix.endswith('.moe_layer.experts'): + if prefix.endswith(".moe_layer.experts"): # this module should be an Experts instance assert isinstance(module, MoeExperts) @@ -25,16 +25,16 @@ def load_moe_model(model: nn.Module, load_path: str): for i in range(num_local): expert_id = ep_rank * num_local + i for name, _ in module.experts[i].named_parameters(): - cur_key = f'{prefix}.experts.{i}.{name}' - param_key = f'{prefix}.experts.{expert_id}.{name}' + cur_key = f"{prefix}.experts.{i}.{name}" + param_key = f"{prefix}.experts.{expert_id}.{name}" load_param = state_dict[param_key] state_dict[cur_key] = load_param for name, _ in module.experts[0].named_parameters(): - pop_pre = f'{prefix}.experts.' - pop_suf = f'.{name}' + pop_pre = f"{prefix}.experts." + pop_suf = f".{name}" for i in range(num_local, module.num_total_experts): - pop_key = f'{pop_pre}{i}{pop_suf}' + pop_key = f"{pop_pre}{i}{pop_suf}" state_dict.pop(pop_key) model.load_state_dict(state_dict) diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 712d872bb921..4b2ecb241702 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -20,8 +20,10 @@ class MoeExperts(nn.Module): def __init__(self, comm_name: str, num_experts: int): super().__init__() - assert comm_name in {"all_to_all", "all_gather"}, \ - "This kind of communication has not been implemented yet.\n Please use Experts build function." + assert comm_name in { + "all_to_all", + "all_gather", + }, "This kind of communication has not been implemented yet.\n Please use Experts build function." self.comm_name = comm_name self.num_total_experts = num_experts # Get the configuration of experts' deployment and parallel information from moe context @@ -50,7 +52,7 @@ def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args) # Attach parallel information for all parameters in Experts for exp in self.experts: for param in exp.parameters(): - param.__setattr__('moe_info', self.dist_info) + param.__setattr__("moe_info", self.dist_info) def forward(self, inputs: torch.Tensor): # Split inputs for each expert @@ -65,7 +67,7 @@ def forward(self, inputs: torch.Tensor): output = torch.cat(expert_output, dim=1).contiguous() return output - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): assert keep_vars == False, "Only support keep_vars=False now" dp_rank = dist.get_rank(self.dist_info.dp_group) ep_rank = dist.get_rank(self.dist_info.ep_group) @@ -79,11 +81,11 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): example_submodule = subm if dp_rank == 0: - local_prefix = prefix + 'experts.' + local_prefix = prefix + "experts." buffer_module = deepcopy(example_submodule) for i in range(self.num_total_experts): source_rank = i // self.num_local_experts - current_prefix = local_prefix + str(i) + '.' + current_prefix = local_prefix + str(i) + "." comm_module = submodule_dict.get(i, buffer_module) for name, param in comm_module.named_parameters(): dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group) @@ -94,8 +96,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): class FFNExperts(MoeExperts): - """Use torch.bmm to speed up for multiple experts. - """ + """Use torch.bmm to speed up for multiple experts.""" def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): super().__init__("all_to_all", num_experts) @@ -119,10 +120,9 @@ def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, d self.drop = nn.Dropout(p=drop_rate) for param in self.parameters(): - param.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, el, c, h] + param.__setattr__("moe_info", self.dist_info) + def forward(self, inputs): # inputs [g, el, c, h] el = inputs.size(1) h = inputs.size(-1) @@ -137,7 +137,7 @@ def forward(self, inputs): # inputs [g, el, c, h] out_model = torch.baddbmm(self.b2, out_inter, self.w2) with seed(ParallelMode.TENSOR): - outputs = self.drop(out_model) # outputs [el, gc, h] + outputs = self.drop(out_model) # outputs [el, gc, h] outputs = outputs.reshape(inshape) outputs = outputs.transpose(0, 1).contiguous() @@ -153,8 +153,7 @@ class TPExperts(MoeExperts): def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): super().__init__("all_gather", MOE_CONTEXT.max_ep_size) - assert d_ff % MOE_CONTEXT.max_ep_size == 0, \ - "d_ff should be divide by maximum expert parallel size" + assert d_ff % MOE_CONTEXT.max_ep_size == 0, "d_ff should be divide by maximum expert parallel size" p_ff = d_ff // MOE_CONTEXT.max_ep_size @@ -177,12 +176,11 @@ def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, d self.act = nn.GELU() if activation is None else activation self.drop = nn.Dropout(p=drop_rate) - self.w1.__setattr__('moe_info', self.dist_info) - self.w2.__setattr__('moe_info', self.dist_info) - self.b1.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, e, c, h] + self.w1.__setattr__("moe_info", self.dist_info) + self.w2.__setattr__("moe_info", self.dist_info) + self.b1.__setattr__("moe_info", self.dist_info) + def forward(self, inputs): # inputs [g, e, c, h] e = inputs.size(1) h = inputs.size(-1) @@ -196,8 +194,8 @@ def forward(self, inputs): # inputs [g, e, c, h] out_inter = self.drop(out_act) out_model = torch.baddbmm(self.b2, out_inter, self.w2) - outputs = self.drop(out_model) # outputs [e, gc, h] + outputs = self.drop(out_model) # outputs [e, gc, h] outputs = outputs.reshape(inshape) outputs = outputs.transpose(0, 1).contiguous() - return outputs # outputs [g, e, c, h] + return outputs # outputs [g, e, c, h] diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 9293d3208f11..23d483e6a17a 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -89,8 +89,9 @@ def forward(self, inputs: torch.Tensor) -> Tuple: elif self.experts.comm_name == "all_gather": expert_output = self.tp_process(dispatch_data) else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " - "build function.") + raise NotImplementedError( + "This kind of communication has not been implemented yet.\n Please use Experts " "build function." + ) # expert_output [e, c, h] if self.use_kernel: expert_output = expert_output.reshape(-1, self.d_model) @@ -135,27 +136,29 @@ class MoeModule(nn.Module): https://arxiv.org/abs/2201.05596 """ - def __init__(self, - dim_model: int, - num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - use_residual: bool = False, - residual_instance: Optional[nn.Module] = None, - expert_instance: Optional[MoeExperts] = None, - expert_cls: Optional[Type[nn.Module]] = None, - **expert_args): + def __init__( + self, + dim_model: int, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + use_residual: bool = False, + residual_instance: Optional[nn.Module] = None, + expert_instance: Optional[MoeExperts] = None, + expert_cls: Optional[Type[nn.Module]] = None, + **expert_args, + ): super().__init__() noisy_func = None if noisy_policy is not None: - if noisy_policy == 'Jitter': + if noisy_policy == "Jitter": noisy_func = UniformNoiseGenerator() - elif noisy_policy == 'Gaussian': + elif noisy_policy == "Gaussian": noisy_func = NormalNoiseGenerator(num_experts) else: raise NotImplementedError("Unsupported input noisy policy") @@ -167,18 +170,19 @@ def __init__(self, else: raise NotImplementedError("top_k > 2 is not supported yet") - self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + self.moe_router = moe_router_cls( + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) self.use_residual = use_residual if use_residual: if residual_instance is not None: self.residual_module = residual_instance else: - assert expert_cls is not None, \ - "Expert class can't be None when residual instance is not given" + assert expert_cls is not None, "Expert class can't be None when residual instance is not given" self.residual_module = expert_cls(**expert_args) with no_shard_zero_context(): @@ -187,14 +191,12 @@ def __init__(self, if expert_instance is not None: my_experts = expert_instance else: - assert expert_cls is not None, \ - "Expert class can't be None when experts instance is not given" + assert expert_cls is not None, "Expert class can't be None when experts instance is not given" my_experts = Experts(expert_cls, num_experts, **expert_args) - self.moe_layer = MoeLayer(dim_model=dim_model, - num_experts=num_experts, - router=self.moe_router, - experts=my_experts) + self.moe_layer = MoeLayer( + dim_model=dim_model, num_experts=num_experts, router=self.moe_router, experts=my_experts + ) def forward(self, inputs: torch.Tensor): moe_output, l_aux = self.moe_layer(inputs) diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py index c5b8390bf047..7ba83b2787a0 100644 --- a/colossalai/nn/layer/moe/routers.py +++ b/colossalai/nn/layer/moe/routers.py @@ -1,226 +1,235 @@ -import math -from abc import ABC - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.distributed as dist -from colossalai.utils import get_current_device -from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer.moe._operation import moe_cumsum -from typing import Callable, Optional -from torch.distributed import ProcessGroup - - -class MoeRouter(nn.Module, ABC): - """Base class for all MoE routers. - Args: - k_value (int): The value of top_k. - capacity_factor_train (float): Capacity factor in routing of training. - capacity_factor_eval (float): Capacity factor in routing of evaluation. - min_capacity (int): The minimum number of the capacity of each expert. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__(self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__() - self.k_value = k_value - self.capacity_factor_train = capacity_factor_train - self.capacity_factor_eval = capacity_factor_eval - self.min_capacity = min_capacity - self.noisy_func = noisy_func - self.drop_tks = drop_tks - self._routing_loss = None - - def get_capacity(self, logits_shape): - capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) - capacity += capacity % 2 - capacity = max(capacity, self.min_capacity) - assert capacity > 0 - return capacity - - def set_routing_loss(self, aux_loss: torch.Tensor) -> None: - assert self._routing_loss is None - self._routing_loss = aux_loss - - def pop_routing_loss(self) -> torch.Tensor: - assert self._routing_loss is not None - reservation = self._routing_loss - self._routing_loss = None - return reservation - - -class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about Switch Transformer - of Google. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert. - select_policy (str, optional): The policy about tokens selection. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__(k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - self.select_policy = select_policy - assert select_policy in {"first", "random"} - if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, - device=get_current_device())).rsample - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(inputs, dim=-1) - mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(mask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(mask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - if self.select_policy == "random": - rand_mask = mask * self.uniform(mask.shape) - _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) - mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) - ranks = moe_cumsum(mask) - elif self.select_policy == "first": - ranks = moe_cumsum(mask) - mask = mask * torch.lt(ranks, capacity) - else: - raise NotImplementedError("Not support such select policy yet.") - - ranks = torch.sum(mask * ranks, dim=-1) - - if use_kernel: - mask = torch.sum(mask, dim=-1) - mask = torch.stack([mask], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) - return logits, mask, dest_idx, num_experts * capacity - else: - ranks = F.one_hot(ranks, num_classes=capacity) - weight = mask * logits.type_as(inputs) - combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) - sec_mask = combine_weights.bool() - return combine_weights, sec_mask - - -class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More detailed function can be found in the paper about ViT-MoE. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation. - """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__(k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - # inputs: [s, h] - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) # logits: [s, e] - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(logits, dim=-1) - mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) - top2_idx = torch.argmax(logits_except1, dim=-1) - mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - - cmask = (mask1 + mask2) # loss: [s, e] - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(cmask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(cmask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - rank1 = moe_cumsum(mask1) # rank1: [s, e] - rank2 = moe_cumsum(mask2) - rank2 += torch.sum(mask1, dim=-2, keepdim=True) - - mask1 *= torch.lt(rank1, capacity) - mask2 *= torch.lt(rank2, capacity) - - rank1 = torch.sum(mask1 * rank1, dim=-1) - rank2 = torch.sum(mask2 * rank2, dim=-1) - - if use_kernel: - mask1 = torch.sum(mask1, dim=-1) - mask2 = torch.sum(mask2, dim=-1) - - mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - - return logits, mask, dest_idx, num_experts * capacity - else: - weight1 = mask1 * logits.type_as(inputs) - weight2 = mask2 * logits.type_as(inputs) - rank1_sc = F.one_hot(rank1, num_classes=capacity) - rank2_sc = F.one_hot(rank2, num_classes=capacity) - - cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - cb_weight = cb_weight1 + cb_weight2 - sec_mask = cb_weight.bool() - - return cb_weight, sec_mask +import math +from abc import ABC +from typing import Callable, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import ProcessGroup + +from colossalai.nn.layer.moe._operation import moe_cumsum +from colossalai.utils import get_current_device + + +class MoeRouter(nn.Module, ABC): + """Base class for all MoE routers. + Args: + k_value (int): The value of top_k. + capacity_factor_train (float): Capacity factor in routing of training. + capacity_factor_eval (float): Capacity factor in routing of evaluation. + min_capacity (int): The minimum number of the capacity of each expert. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__( + self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Callable = None, + drop_tks: bool = True, + ): + super().__init__() + self.k_value = k_value + self.capacity_factor_train = capacity_factor_train + self.capacity_factor_eval = capacity_factor_eval + self.min_capacity = min_capacity + self.noisy_func = noisy_func + self.drop_tks = drop_tks + self._routing_loss = None + + def get_capacity(self, logits_shape): + capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval + capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity += capacity % 2 + capacity = max(capacity, self.min_capacity) + assert capacity > 0 + return capacity + + def set_routing_loss(self, aux_loss: torch.Tensor) -> None: + assert self._routing_loss is None + self._routing_loss = aux_loss + + def pop_routing_loss(self) -> torch.Tensor: + assert self._routing_loss is not None + reservation = self._routing_loss + self._routing_loss = None + return reservation + + +class Top1Router(MoeRouter): + """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More detailed function can be found in the paper about Switch Transformer + of Google. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert. + select_policy (str, optional): The policy about tokens selection. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Callable = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) + self.select_policy = select_policy + assert select_policy in {"first", "random"} + if select_policy == "random": + self.uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device()) + ).rsample + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(inputs, dim=-1) + mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(mask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(mask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + if self.select_policy == "random": + rand_mask = mask * self.uniform(mask.shape) + _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) + mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) + ranks = moe_cumsum(mask) + elif self.select_policy == "first": + ranks = moe_cumsum(mask) + mask = mask * torch.lt(ranks, capacity) + else: + raise NotImplementedError("Not support such select policy yet.") + + ranks = torch.sum(mask * ranks, dim=-1) + + if use_kernel: + mask = torch.sum(mask, dim=-1) + mask = torch.stack([mask], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) + return logits, mask, dest_idx, num_experts * capacity + else: + ranks = F.one_hot(ranks, num_classes=capacity) + weight = mask * logits.type_as(inputs) + combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) + sec_mask = combine_weights.bool() + return combine_weights, sec_mask + + +class Top2Router(MoeRouter): + """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More detailed function can be found in the paper about ViT-MoE. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation. + """ + + def __init__( + self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Callable = None, + drop_tks: bool = True, + ): + super().__init__( + k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + # inputs: [s, h] + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) # logits: [s, e] + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(logits, dim=-1) + mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + top2_idx = torch.argmax(logits_except1, dim=-1) + mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) + + cmask = mask1 + mask2 # loss: [s, e] + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(cmask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(cmask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + rank1 = moe_cumsum(mask1) # rank1: [s, e] + rank2 = moe_cumsum(mask2) + rank2 += torch.sum(mask1, dim=-2, keepdim=True) + + mask1 *= torch.lt(rank1, capacity) + mask2 *= torch.lt(rank2, capacity) + + rank1 = torch.sum(mask1 * rank1, dim=-1) + rank2 = torch.sum(mask2 * rank2, dim=-1) + + if use_kernel: + mask1 = torch.sum(mask1, dim=-1) + mask2 = torch.sum(mask2, dim=-1) + + mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) + + return logits, mask, dest_idx, num_experts * capacity + else: + weight1 = mask1 * logits.type_as(inputs) + weight2 = mask2 * logits.type_as(inputs) + rank1_sc = F.one_hot(rank1, num_classes=capacity) + rank2_sc = F.one_hot(rank2, num_classes=capacity) + + cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + cb_weight = cb_weight1 + cb_weight2 + sec_mask = cb_weight.bool() + + return cb_weight, sec_mask diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py index 4ca8bd703386..4f31dd5579dc 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -1,68 +1,71 @@ -import torch -import torch.nn.functional as F -from colossalai.utils import get_current_device -from colossalai.context.moe_context import MOE_CONTEXT -from .experts import FFNExperts, TPExperts - - -class ForceFP32Parameter(torch.nn.Parameter): - - def half(self, memory_format=None): - return self.data.clone() - - -class NormalNoiseGenerator: - """Generates a random noisy mask for logits tensor. - - All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where - `E = the number of experts`. - - Args: - num_experts (int): The number of experts. - """ - - def __init__(self, num_experts: int): - self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, - device=get_current_device())).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.normal(inputs.shape) - return inputs + noisy - - -class UniformNoiseGenerator: - """Generates a random noisy mask for logits tensor. - copied from mesh tensorflow: - Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. - Makes models more resilient to rounding errors introduced by bfloat16. - This seems particularly important for logits. - - Args: - eps (float, optional): Epsilon in generator, defaults 1e-2. - """ - - def __init__(self, eps: float = 1e-2): - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, - device=get_current_device())).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.uniform(inputs.shape) - return inputs * noisy - - -def autocast_softmax(logit: torch.Tensor, dim: int): - if logit.dtype != torch.float32: - logit = logit.float() - return F.softmax(logit, dim=dim) - - -def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - mep_size = MOE_CONTEXT.max_ep_size - if num_experts % mep_size == 0 or mep_size % num_experts == 0: - return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) - elif d_ff % mep_size == 0: - return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) - else: - raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") +import torch +import torch.nn.functional as F + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.utils import get_current_device + +from .experts import FFNExperts, TPExperts + + +class ForceFP32Parameter(torch.nn.Parameter): + def half(self, memory_format=None): + return self.data.clone() + + +class NormalNoiseGenerator: + """Generates a random noisy mask for logits tensor. + + All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where + `E = the number of experts`. + + Args: + num_experts (int): The number of experts. + """ + + def __init__(self, num_experts: int): + self.normal = torch.distributions.normal.Normal( + loc=torch.tensor(0.0, device=get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()), + ).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.normal(inputs.shape) + return inputs + noisy + + +class UniformNoiseGenerator: + """Generates a random noisy mask for logits tensor. + copied from mesh tensorflow: + Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + + Args: + eps (float, optional): Epsilon in generator, defaults 1e-2. + """ + + def __init__(self, eps: float = 1e-2): + self.uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(1.0 - eps, device=get_current_device()), + high=torch.tensor(1.0 + eps, device=get_current_device()), + ).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.uniform(inputs.shape) + return inputs * noisy + + +def autocast_softmax(logit: torch.Tensor, dim: int): + if logit.dtype != torch.float32: + logit = logit.float() + return F.softmax(logit, dim=dim) + + +def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + mep_size = MOE_CONTEXT.max_ep_size + if num_experts % mep_size == 0 or mep_size % num_experts == 0: + return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) + elif d_ff % mep_size == 0: + return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) + else: + raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") diff --git a/colossalai/nn/layer/utils.py b/colossalai/nn/layer/utils.py index dc12ff8daa4e..ff9b5c8f2b5b 100644 --- a/colossalai/nn/layer/utils.py +++ b/colossalai/nn/layer/utils.py @@ -8,7 +8,6 @@ def divide(numerator, denominator): Returns: int: the result of exact division. """ - assert denominator != 0, 'denominator can not be zero' - assert numerator % denominator == 0, \ - '{} is not divisible by {}'.format(numerator, denominator) + assert denominator != 0, "denominator can not be zero" + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) return numerator // denominator diff --git a/colossalai/nn/lr_scheduler/__init__.py b/colossalai/nn/lr_scheduler/__init__.py index 34731ee901a0..783f12f8c7c4 100644 --- a/colossalai/nn/lr_scheduler/__init__.py +++ b/colossalai/nn/lr_scheduler/__init__.py @@ -3,10 +3,21 @@ from .multistep import MultiStepLR, MultiStepWarmupLR from .onecycle import OneCycleLR from .poly import PolynomialLR, PolynomialWarmupLR -from .torch import LambdaLR, MultiplicativeLR, StepLR, ExponentialLR +from .torch import ExponentialLR, LambdaLR, MultiplicativeLR, StepLR __all__ = [ - 'CosineAnnealingLR', 'CosineAnnealingWarmupLR', 'FlatAnnealingLR', 'FlatAnnealingWarmupLR', 'LinearWarmupLR', - 'MultiStepLR', 'MultiStepWarmupLR', 'OneCycleLR', 'PolynomialLR', 'PolynomialWarmupLR', 'LambdaLR', - 'MultiplicativeLR', 'StepLR', 'ExponentialLR' + "CosineAnnealingLR", + "CosineAnnealingWarmupLR", + "FlatAnnealingLR", + "FlatAnnealingWarmupLR", + "LinearWarmupLR", + "MultiStepLR", + "MultiStepWarmupLR", + "OneCycleLR", + "PolynomialLR", + "PolynomialWarmupLR", + "LambdaLR", + "MultiplicativeLR", + "StepLR", + "ExponentialLR", ] diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py index fb587e1a1341..a896d3acba6c 100644 --- a/colossalai/nn/lr_scheduler/cosine.py +++ b/colossalai/nn/lr_scheduler/cosine.py @@ -58,11 +58,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0., last_epoch: int = -1): - base_scheduler = _CosineAnnealingLR(optimizer, - total_steps - warmup_steps, - eta_min=eta_min, - last_epoch=last_epoch) + def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: float = 0.0, last_epoch: int = -1): + base_scheduler = _CosineAnnealingLR( + optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch + ) super().__init__(optimizer, warmup_steps, base_scheduler) @@ -79,7 +78,7 @@ class FlatAnnealingLR(DelayerScheduler): def __init__(self, optimizer, total_steps: int, pct_start: float = 0.72, last_epoch: int = -1, **kwargs): if not (0.0 <= pct_start <= 1.0): - raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}') + raise ValueError(f"pct_start must >= 0.0 and <= 1.0, got {pct_start}") flat_steps = int(total_steps * pct_start) anneal_steps = total_steps - flat_steps base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps) @@ -100,16 +99,18 @@ class FlatAnnealingWarmupLR(WarmupDelayerScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - warmup_steps: int = 0, - pct_start: float = 0.72, - eta_min: int = 0, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + pct_start: float = 0.72, + eta_min: int = 0, + last_epoch: int = -1, + **kwargs, + ): if not (0.0 <= pct_start <= 1.0): - raise ValueError(f'pct_start must >= 0.0 and <= 1.0, got {pct_start}') + raise ValueError(f"pct_start must >= 0.0 and <= 1.0, got {pct_start}") flat_steps = int((total_steps - warmup_steps) * pct_start) anneal_steps = total_steps - warmup_steps - flat_steps base_scheduler = _CosineAnnealingLR(optimizer, anneal_steps, eta_min=eta_min) diff --git a/colossalai/nn/lr_scheduler/delayed.py b/colossalai/nn/lr_scheduler/delayed.py index a73ff8ae37ac..ce7f126d6101 100644 --- a/colossalai/nn/lr_scheduler/delayed.py +++ b/colossalai/nn/lr_scheduler/delayed.py @@ -2,7 +2,6 @@ class _enable_get_lr_call: - def __init__(self, o): self.o = o @@ -28,18 +27,18 @@ class DelayerScheduler(_LRScheduler): def __init__(self, optimizer, delay_epochs, after_scheduler, last_epoch=-1): if delay_epochs < 0: - raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}') + raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") self.delay_epochs = delay_epochs self.after_scheduler = after_scheduler self.finished = False super().__init__(optimizer, last_epoch) def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} - if isinstance(state_dict['after_scheduler'], _LRScheduler): - state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ - state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() - del state_dict['after_scheduler'] + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] else: raise NotImplementedError() return state_dict @@ -85,11 +84,11 @@ def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1): super().__init__(optimizer, last_epoch) def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} - if isinstance(state_dict['after_scheduler'], _LRScheduler): - state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ - state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() - del state_dict['after_scheduler'] + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] else: raise NotImplementedError() return state_dict @@ -130,9 +129,9 @@ class WarmupDelayerScheduler(_LRScheduler): def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last_epoch=-1): if delay_epochs < 0: - raise ValueError(f'delay_epochs must >= 0, got {delay_epochs}') + raise ValueError(f"delay_epochs must >= 0, got {delay_epochs}") if warmup_epochs < 0: - raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}') + raise ValueError(f"warmup_epochs must >= 0, got {warmup_epochs}") self.warmup_epochs = warmup_epochs self.delay_epochs = delay_epochs self.after_scheduler = after_scheduler @@ -140,11 +139,11 @@ def __init__(self, optimizer, warmup_epochs, delay_epochs, after_scheduler, last super().__init__(optimizer, last_epoch) def state_dict(self): - state_dict = {key: value for key, value in self.__dict__.items() if key not in 'optimizer'} - if isinstance(state_dict['after_scheduler'], _LRScheduler): - state_dict['after_scheduler_type'] = type(state_dict['after_scheduler']).__name__ - state_dict['after_scheduler_dict'] = state_dict['after_scheduler'].state_dict() - del state_dict['after_scheduler'] + state_dict = {key: value for key, value in self.__dict__.items() if key not in "optimizer"} + if isinstance(state_dict["after_scheduler"], _LRScheduler): + state_dict["after_scheduler_type"] = type(state_dict["after_scheduler"]).__name__ + state_dict["after_scheduler_dict"] = state_dict["after_scheduler"].state_dict() + del state_dict["after_scheduler"] else: raise NotImplementedError() return state_dict @@ -155,7 +154,7 @@ def get_lr(self): self.after_scheduler.base_lrs = self.base_lrs # reset lr to base_lr for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs): - group['lr'] = base_lr + group["lr"] = base_lr self.finished = True with _enable_get_lr_call(self.after_scheduler): return self.after_scheduler.get_lr() diff --git a/colossalai/nn/lr_scheduler/linear.py b/colossalai/nn/lr_scheduler/linear.py index 21a865e4c12b..1251c261d51f 100644 --- a/colossalai/nn/lr_scheduler/linear.py +++ b/colossalai/nn/lr_scheduler/linear.py @@ -21,5 +21,7 @@ def get_lr(self): if self.last_epoch < self.warmup_steps: return [(self.last_epoch + 1) / (self.warmup_steps + 1) * lr for lr in self.base_lrs] else: - return [(self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr - for lr in self.base_lrs] + return [ + (self.total_steps - self.last_epoch) / (self.total_steps - self.warmup_steps) * lr + for lr in self.base_lrs + ] diff --git a/colossalai/nn/lr_scheduler/multistep.py b/colossalai/nn/lr_scheduler/multistep.py index c428c911c94d..86589d74662d 100644 --- a/colossalai/nn/lr_scheduler/multistep.py +++ b/colossalai/nn/lr_scheduler/multistep.py @@ -20,13 +20,15 @@ class MultiStepLR(_MultiStepLR): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - milestones: List[int] = None, - gamma: float = 0.1, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs, + ): super().__init__(optimizer, milestones, gamma=gamma, last_epoch=last_epoch) @@ -44,16 +46,18 @@ class MultiStepWarmupLR(WarmupScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - warmup_steps: int = 0, - milestones: List[int] = None, - gamma: float = 0.1, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + milestones: List[int] = None, + gamma: float = 0.1, + last_epoch: int = -1, + **kwargs, + ): if len(milestones) == 0: - raise ValueError('milestones cannot be empty') + raise ValueError("milestones cannot be empty") milestones = [v - warmup_steps for v in milestones if v >= warmup_steps] base_scheduler = _MultiStepLR(optimizer, milestones=milestones, gamma=gamma) super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/lr_scheduler/onecycle.py b/colossalai/nn/lr_scheduler/onecycle.py index 6835b3ee1cf2..a8e551526dbd 100644 --- a/colossalai/nn/lr_scheduler/onecycle.py +++ b/colossalai/nn/lr_scheduler/onecycle.py @@ -65,27 +65,31 @@ class OneCycleLR(_OneCycleLR): https://arxiv.org/abs/1708.07120 """ - def __init__(self, - optimizer, - total_steps: int, - pct_start=0.3, - anneal_strategy='cos', - cycle_momentum=True, - base_momentum=0.85, - max_momentum=0.95, - div_factor=25.0, - final_div_factor=10000.0, - last_epoch=-1, - **kwargs): - max_lrs = list(map(lambda group: group['lr'], optimizer.param_groups)) - super().__init__(optimizer, - max_lrs, - total_steps=total_steps, - pct_start=pct_start, - anneal_strategy=anneal_strategy, - cycle_momentum=cycle_momentum, - base_momentum=base_momentum, - max_momentum=max_momentum, - div_factor=div_factor, - final_div_factor=final_div_factor, - last_epoch=last_epoch) + def __init__( + self, + optimizer, + total_steps: int, + pct_start=0.3, + anneal_strategy="cos", + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=10000.0, + last_epoch=-1, + **kwargs, + ): + max_lrs = list(map(lambda group: group["lr"], optimizer.param_groups)) + super().__init__( + optimizer, + max_lrs, + total_steps=total_steps, + pct_start=pct_start, + anneal_strategy=anneal_strategy, + cycle_momentum=cycle_momentum, + base_momentum=base_momentum, + max_momentum=max_momentum, + div_factor=div_factor, + final_div_factor=final_div_factor, + last_epoch=last_epoch, + ) diff --git a/colossalai/nn/lr_scheduler/poly.py b/colossalai/nn/lr_scheduler/poly.py index 4f2249720ef6..4a3814461ea9 100644 --- a/colossalai/nn/lr_scheduler/poly.py +++ b/colossalai/nn/lr_scheduler/poly.py @@ -15,15 +15,11 @@ class PolynomialLR(_LRScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - end_lr: float = 0.0001, - power: float = 1.0, - last_epoch: int = -1, - **kwargs): + def __init__( + self, optimizer, total_steps: int, end_lr: float = 0.0001, power: float = 1.0, last_epoch: int = -1, **kwargs + ): if end_lr < 0: - raise ValueError(f'end_lr must >= 0, got {end_lr}') + raise ValueError(f"end_lr must >= 0, got {end_lr}") self.total_steps = total_steps self.end_lr = end_lr self.power = power @@ -33,9 +29,11 @@ def get_lr(self): return self._get_closed_form_lr() def _get_closed_form_lr(self): - return [(base_lr - self.end_lr) * - ((1 - min(self.last_epoch, self.total_steps) / self.total_steps)**self.power) + self.end_lr - for base_lr in self.base_lrs] + return [ + (base_lr - self.end_lr) * ((1 - min(self.last_epoch, self.total_steps) / self.total_steps) ** self.power) + + self.end_lr + for base_lr in self.base_lrs + ] class PolynomialWarmupLR(WarmupScheduler): @@ -51,13 +49,15 @@ class PolynomialWarmupLR(WarmupScheduler): the schedule is started from the beginning or When last_epoch=-1, sets initial lr as lr. """ - def __init__(self, - optimizer, - total_steps: int, - warmup_steps: int = 0, - end_lr: float = 0.0001, - power: float = 1.0, - last_epoch: int = -1, - **kwargs): + def __init__( + self, + optimizer, + total_steps: int, + warmup_steps: int = 0, + end_lr: float = 0.0001, + power: float = 1.0, + last_epoch: int = -1, + **kwargs, + ): base_scheduler = PolynomialLR(optimizer, total_steps - warmup_steps, end_lr=end_lr, power=power) super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch) diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md index d839753d6c44..c4afc6128d43 100644 --- a/colossalai/nn/optimizer/README.md +++ b/colossalai/nn/optimizer/README.md @@ -3,7 +3,7 @@ ## Introduction Welcome to the large-scale deep learning optimization techniques of [Colossal-AI](https://github.com/hpcaitech/ColossalAI), -which has been accepted as official tutorials by top conference [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), +which has been accepted as official tutorials by top conference [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) ,etc. diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index 7e310793f515..26f152da20d3 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -6,4 +6,4 @@ from .lamb import Lamb from .lars import Lars -__all__ = ['FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam'] +__all__ = ["FusedLAMB", "FusedAdam", "FusedSGD", "Lamb", "Lars", "CPUAdam", "HybridAdam"] diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 9767fcb8b1e2..f35dc0200237 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -61,36 +61,39 @@ class CPUAdam(NVMeOptimizer): # Param weight, grad, momentum and variance num_fp32_shards_per_param = 4 - def __init__(self, - model_params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - adamw_mode=True, - nvme_offload_fraction: float = 0.0, - nvme_offload_dir: Optional[str] = None): - + def __init__( + self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + adamw_mode=True, + nvme_offload_fraction: float = 0.0, + nvme_offload_dir: Optional[str] = None, + ): default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode cpu_adam = CPUAdamBuilder().load() self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) - def torch_adam_update(self, - data, - grad, - exp_avg, - exp_avg_sq, - lr, - beta1, - beta2, - eps, - weight_decay, - bias_correction1, - bias_correction2, - use_adamw=False): + def torch_adam_update( + self, + data, + grad, + exp_avg, + exp_avg_sq, + lr, + beta1, + beta2, + eps, + weight_decay, + bias_correction1, + bias_correction2, + use_adamw=False, + ): grad = grad.to(data.dtype) if weight_decay != 0: @@ -117,10 +120,9 @@ def step(self, closure=None, div_scale: float = -1): with torch.enable_grad(): loss = closure() - self._pre_step('exp_avg', 'exp_avg_sq') + self._pre_step("exp_avg", "exp_avg_sq") for _, group in enumerate(self.param_groups): - for _, p in enumerate(group['params']): - + for _, p in enumerate(group["params"]): if p.grad is None: continue @@ -128,48 +130,81 @@ def step(self, closure=None, div_scale: float = -1): target_device = p.device if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # FIXME(ver217): CPU adam kernel only supports fp32 states now assert p.dtype is torch.float, "CPUAdam only support fp32 parameters" # gradient momentums - state['exp_avg'] = torch.zeros_like(p, device=target_device) + state["exp_avg"] = torch.zeros_like(p, device=target_device) # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p, device=target_device) + state["exp_avg_sq"] = torch.zeros_like(p, device=target_device) self._post_state_init(p) - state['step'] += 1 - beta1, beta2 = group['betas'] + state["step"] += 1 + beta1, beta2 = group["betas"] - if target_device.type == 'cpu': + if target_device.type == "cpu": assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size" - assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" - assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" - self._pre_update(p, 'exp_avg', 'exp_avg_sq') + assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" + assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" + self._pre_update(p, "exp_avg", "exp_avg_sq") if p.grad.dtype is torch.bfloat16: # cpu adam kernel does not support bf16 now - bias_correction1 = 1 - beta1**state['step'] - bias_correction2 = 1 - beta2**state['step'] - self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], - beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, - bias_correction2, self.adamw_mode) + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + self.torch_adam_update( + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + bias_correction1, + bias_correction2, + self.adamw_mode, + ) else: - self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], - group['weight_decay'], group['bias_correction'], p.data, p.grad.data, - state['exp_avg'], state['exp_avg_sq'], div_scale) - self._post_update(p, 'exp_avg', 'exp_avg_sq') - elif target_device.type == 'cuda': + self.cpu_adam_op.step( + state["step"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + group["bias_correction"], + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + div_scale, + ) + self._post_update(p, "exp_avg", "exp_avg_sq") + elif target_device.type == "cuda": assert div_scale == -1, "div_scale should remain default" - assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" - assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" + assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda" + assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda" - bias_correction1 = 1 - beta1**state['step'] - bias_correction2 = 1 - beta2**state['step'] + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] # adam on cuda - self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], - beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, - bias_correction2, self.adamw_mode) + self.torch_adam_update( + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + bias_correction1, + bias_correction2, + self.adamw_mode, + ) else: raise RuntimeError self._post_step() diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index 3a05a34f52d2..fcdd3257d700 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -1,11 +1,11 @@ # modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py -''' +""" Copyright 2020 The Microsoft DeepSpeed Team Copyright NVIDIA/apex This file is adapted from fused adam in NVIDIA/apex, commit a109f85 Licensed under the MIT License. -''' +""" import torch from colossalai.utils import multi_tensor_applier @@ -51,37 +51,39 @@ class FusedAdam(torch.optim.Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - adamw_mode=True, - weight_decay=0., - amsgrad=False, - set_grad_none=True): - + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + adamw_mode=True, + weight_decay=0.0, + amsgrad=False, + set_grad_none=True, + ): if amsgrad: - raise RuntimeError('FusedAdam does not support the AMSGrad variant.') + raise RuntimeError("FusedAdam does not support the AMSGrad variant.") defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) super(FusedAdam, self).__init__(params, defaults) self.adamw_mode = 1 if adamw_mode else 0 self.set_grad_none = set_grad_none if multi_tensor_applier.available: from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self.multi_tensor_adam = fused_optim.multi_tensor_adam else: - raise RuntimeError('FusedAdam requires cuda extensions') + raise RuntimeError("FusedAdam requires cuda extensions") def zero_grad(self, set_to_none=False): if set_to_none: for group in self.param_groups: - for p in group['params']: + for p in group["params"]: p.grad = None else: super(FusedAdam, self).zero_grad() @@ -97,51 +99,63 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no """ if any(p is not None for p in [grads, output_params, scale, grad_norms]): raise RuntimeError( - 'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.' + "FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments." ) loss = None if closure is not None: loss = closure() for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] # assume same step across group now to simplify things # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 + if "step" in group: + group["step"] += 1 else: - group['step'] = 1 + group["step"] = 1 # create lists for multi-tensor apply g_l, p_l, m_l, v_l = [], [], [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: raise RuntimeError( - 'FusedAdam does not support sparse gradients, please consider SparseAdam instead') + "FusedAdam does not support sparse gradients, please consider SparseAdam instead" + ) state = self.state[p] # State initialization if len(state) == 0: # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]: - raise RuntimeError('FusedAdam only support fp16, fp32 and bf16.') + raise RuntimeError("FusedAdam only support fp16, fp32 and bf16.") g_l.append(p.grad.data) p_l.append(p.data) - m_l.append(state['exp_avg']) - v_l.append(state['exp_avg_sq']) - - multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], - beta1, beta2, group['eps'], group['step'], self.adamw_mode, bias_correction, - group['weight_decay'], div_scale) + m_l.append(state["exp_avg"]) + v_l.append(state["exp_avg_sq"]) + + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_l, p_l, m_l, v_l], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adamw_mode, + bias_correction, + group["weight_decay"], + div_scale, + ) return loss diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index a2807d70f454..3e1d5a7ba539 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -49,41 +49,46 @@ class FusedLAMB(torch.optim.Optimizer): https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=0.01, - amsgrad=False, - adam_w_mode=True, - grad_averaging=True, - set_grad_none=True, - max_grad_norm=1.0, - use_nvlamb=False): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=0.01, + amsgrad=False, + adam_w_mode=True, + grad_averaging=True, + set_grad_none=True, + max_grad_norm=1.0, + use_nvlamb=False, + ): if amsgrad: - raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') - defaults = dict(lr=lr, - bias_correction=bias_correction, - betas=betas, - eps=eps, - weight_decay=weight_decay, - grad_averaging=grad_averaging, - max_grad_norm=max_grad_norm) + raise RuntimeError("FusedLAMB does not support the AMSGrad variant.") + defaults = dict( + lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + max_grad_norm=max_grad_norm, + ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], - dtype=torch.int, - device=self.param_groups[0]["params"][0].device) + self._dummy_overflow_buf = torch.tensor( + [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device + ) self.multi_tensor_lamb = fused_optim.multi_tensor_lamb else: - raise RuntimeError('FusedLAMB requires cuda extensions') + raise RuntimeError("FusedLAMB requires cuda extensions") self.adam_w_mode = 1 if adam_w_mode else 0 self.set_grad_none = set_grad_none @@ -92,7 +97,7 @@ def __init__(self, def zero_grad(self): if self.set_grad_none: for group in self.param_groups: - for p in group['params']: + for p in group["params"]: p.grad = None else: super(FusedLAMB, self).zero_grad() @@ -111,7 +116,7 @@ def step(self, closure=None): # create separate grad lists for fp32 and fp16 params g_all_32, g_all_16 = [], [] for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.dtype == torch.float32: @@ -119,7 +124,7 @@ def step(self, closure=None): elif p.dtype == torch.float16: g_all_16.append(p.grad.data) else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') + raise RuntimeError("FusedLAMB only support fp16 and fp32.") device = self.param_groups[0]["params"][0].device g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device) @@ -130,63 +135,91 @@ def step(self, closure=None): g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False)[0] # blend two grad norms to get global grad norm - global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, - [[g_norm_32, g_norm_16]], False)[0] - max_grad_norm = self.defaults['max_grad_norm'] + global_grad_norm = multi_tensor_applier( + self.multi_tensor_l2norm, self._dummy_overflow_buf, [[g_norm_32, g_norm_16]], False + )[0] + max_grad_norm = self.defaults["max_grad_norm"] for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] - grad_averaging = 1 if group['grad_averaging'] else 0 + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + grad_averaging = 1 if group["grad_averaging"] else 0 # assume same step across group now to simplify things # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 + if "step" in group: + group["step"] += 1 else: - group['step'] = 1 + group["step"] = 1 # create lists for multi-tensor apply g_16, p_16, m_16, v_16 = [], [], [], [] g_32, p_32, m_32, v_32 = [], [], [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: raise RuntimeError( - 'FusedLAMB does not support sparse gradients, please consider SparseAdam instead') + "FusedLAMB does not support sparse gradients, please consider SparseAdam instead" + ) state = self.state[p] # State initialization if len(state) == 0: # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) if p.dtype == torch.float16: g_16.append(p.grad.data) p_16.append(p.data) - m_16.append(state['exp_avg']) - v_16.append(state['exp_avg_sq']) + m_16.append(state["exp_avg"]) + v_16.append(state["exp_avg_sq"]) elif p.dtype == torch.float32: g_32.append(p.grad.data) p_32.append(p.data) - m_32.append(state['exp_avg']) - v_32.append(state['exp_avg_sq']) + m_32.append(state["exp_avg"]) + v_32.append(state["exp_avg_sq"]) else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') - - if (len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16], - group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction, - group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm, - max_grad_norm, self.use_nvlamb) - if (len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32], - group['lr'], beta1, beta2, group['eps'], group['step'], bias_correction, - group['weight_decay'], grad_averaging, self.adam_w_mode, global_grad_norm, - max_grad_norm, self.use_nvlamb) + raise RuntimeError("FusedLAMB only support fp16 and fp32.") + + if len(g_16) > 0: + multi_tensor_applier( + self.multi_tensor_lamb, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + bias_correction, + group["weight_decay"], + grad_averaging, + self.adam_w_mode, + global_grad_norm, + max_grad_norm, + self.use_nvlamb, + ) + if len(g_32) > 0: + multi_tensor_applier( + self.multi_tensor_lamb, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + bias_correction, + group["weight_decay"], + grad_averaging, + self.adam_w_mode, + global_grad_norm, + max_grad_norm, + self.use_nvlamb, + ) return loss diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 59a93a8be9c7..95a6354208a8 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -54,14 +54,9 @@ class FusedSGD(Optimizer): The Nesterov version is analogously modified. """ - def __init__(self, - params, - lr=required, - momentum=0, - dampening=0, - weight_decay=0, - nesterov=False, - wd_after_momentum=False): + def __init__( + self, params, lr=required, momentum=0, dampening=0, weight_decay=0, nesterov=False, wd_after_momentum=False + ): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -78,20 +73,21 @@ def __init__(self, if multi_tensor_applier.available: from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], - dtype=torch.int, - device=self.param_groups[0]["params"][0].device) + self._dummy_overflow_buf = torch.tensor( + [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device + ) self.multi_tensor_sgd = fused_optim.multi_tensor_sgd else: - raise RuntimeError('FusedSGD requires cuda extensions') + raise RuntimeError("FusedSGD requires cuda extensions") def __setstate__(self, state): super(FusedSGD, self).__setstate__(state) for group in self.param_groups: - group.setdefault('nesterov', False) + group.setdefault("nesterov", False) def get_momentums(self, params): momentums = [] @@ -101,13 +97,13 @@ def get_momentums(self, params): # torch.optim.SGD initializes momentum in the main loop, we have # to do it here, and track whether or not we've done so, so that # momentum application can be skipped in the main kernel. - if 'momentum_buffer' not in param_state: + if "momentum_buffer" not in param_state: first_run = True - buf = param_state['momentum_buffer'] = torch.zeros_like(p) + buf = param_state["momentum_buffer"] = torch.zeros_like(p) momentums.append(buf) else: first_run = False - momentums.append(param_state['momentum_buffer']) + momentums.append(param_state["momentum_buffer"]) return momentums, first_run def step(self, closure=None): @@ -122,10 +118,10 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - weight_decay = group['weight_decay'] - momentum = group['momentum'] - dampening = group['dampening'] - nesterov = group['nesterov'] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + dampening = group["dampening"] + nesterov = group["nesterov"] # For each group, there are 3 possible combinations we need to consider: # grad_type, param_to_update_type, momentum_type @@ -133,15 +129,26 @@ def step(self, closure=None): # 2. fp32, fp32, fp32 # 3. fp16, fp32, fp32 g_l, p_l = [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: - raise RuntimeError('FusedSGD does not support sparse gradients') + raise RuntimeError("FusedSGD does not support sparse gradients") g_l.append(p.grad) p_l.append(p) m_l, first_run = self.get_momentums(p_l) - multi_tensor_applier(self.multi_tensor_sgd, self._dummy_overflow_buf, [g_l, p_l, m_l], weight_decay, - momentum, dampening, group['lr'], nesterov, first_run, self.wd_after_momentum, 1.0) + multi_tensor_applier( + self.multi_tensor_sgd, + self._dummy_overflow_buf, + [g_l, p_l, m_l], + weight_decay, + momentum, + dampening, + group["lr"], + nesterov, + first_run, + self.wd_after_momentum, + 1.0, + ) return loss diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index e08df410effe..32fc6136c4e6 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -1,7 +1,6 @@ from typing import Any, Optional import torch -from torch.optim import Adam from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.utils import multi_tensor_applier @@ -61,20 +60,30 @@ class HybridAdam(CPUAdam): # Param weight, grad, momentum and variance num_fp32_shards_per_param = 4 - def __init__(self, - model_params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - adamw_mode=True, - nvme_offload_fraction: float = 0.0, - nvme_offload_dir: Optional[str] = None, - **defaults: Any): - - super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction, - nvme_offload_dir) + def __init__( + self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + adamw_mode=True, + nvme_offload_fraction: float = 0.0, + nvme_offload_dir: Optional[str] = None, + **defaults: Any, + ): + super().__init__( + model_params, + lr, + bias_correction, + betas, + eps, + weight_decay, + adamw_mode, + nvme_offload_fraction, + nvme_offload_dir, + ) fused_optim = FusedOptimBuilder().load() self.gpu_adam_op = fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) @@ -86,12 +95,11 @@ def step(self, closure=None, div_scale: float = -1): with torch.enable_grad(): loss = closure() - self._pre_step('exp_avg', 'exp_avg_sq') + self._pre_step("exp_avg", "exp_avg_sq") for _, group in enumerate(self.param_groups): g_l, p_l, m_l, v_l = [], [], [], [] group_step = 0 - for _, p in enumerate(group['params']): - + for _, p in enumerate(group["params"]): if p.grad is None: continue @@ -99,54 +107,87 @@ def step(self, closure=None, div_scale: float = -1): target_device = p.device if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # FIXME(ver217): CPU adam kernel only supports fp32 states now assert p.dtype is torch.float, "HybridAdam only support fp32 parameters" # gradient momentums - state['exp_avg'] = torch.zeros_like(p, device=target_device) + state["exp_avg"] = torch.zeros_like(p, device=target_device) # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p, device=target_device) + state["exp_avg_sq"] = torch.zeros_like(p, device=target_device) self._post_state_init(p) - state['step'] += 1 - group_step = state['step'] - beta1, beta2 = group['betas'] + state["step"] += 1 + group_step = state["step"] + beta1, beta2 = group["betas"] - if target_device.type == 'cpu': - assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" - assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" - self._pre_update(p, 'exp_avg', 'exp_avg_sq') + if target_device.type == "cpu": + assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" + assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" + self._pre_update(p, "exp_avg", "exp_avg_sq") if p.grad.dtype is torch.bfloat16: # cpu adam kernel does not support bf16 now - bias_correction1 = 1 - beta1**state['step'] - bias_correction2 = 1 - beta2**state['step'] - self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], - beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, - bias_correction2, self.adamw_mode) + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + self.torch_adam_update( + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + bias_correction1, + bias_correction2, + self.adamw_mode, + ) else: - self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], - group['weight_decay'], group['bias_correction'], p.data, p.grad.data, - state['exp_avg'], state['exp_avg_sq'], div_scale) - self._post_update(p, 'exp_avg', 'exp_avg_sq') - - elif target_device.type == 'cuda': - assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" - assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" + self.cpu_adam_op.step( + state["step"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + group["bias_correction"], + p.data, + p.grad.data, + state["exp_avg"], + state["exp_avg_sq"], + div_scale, + ) + self._post_update(p, "exp_avg", "exp_avg_sq") + + elif target_device.type == "cuda": + assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda" + assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda" # record the state by group and update at once g_l.append(p.grad.data) p_l.append(p.data) - m_l.append(state['exp_avg']) - v_l.append(state['exp_avg_sq']) + m_l.append(state["exp_avg"]) + v_l.append(state["exp_avg_sq"]) else: raise RuntimeError if len(g_l) > 0: adamw_mode = 1 if self.adamw_mode else 0 - bias_correction = 1 if group['bias_correction'] else 0 - multi_tensor_applier(self.gpu_adam_op, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l], group['lr'], - group['betas'][0], group['betas'][1], group['eps'], group_step, adamw_mode, - bias_correction, group['weight_decay'], div_scale) + bias_correction = 1 if group["bias_correction"] else 0 + multi_tensor_applier( + self.gpu_adam_op, + self._dummy_overflow_buf, + [g_l, p_l, m_l, v_l], + group["lr"], + group["betas"][0], + group["betas"][1], + group["eps"], + group_step, + adamw_mode, + bias_correction, + group["weight_decay"], + div_scale, + ) self._post_step() return loss diff --git a/colossalai/nn/optimizer/lamb.py b/colossalai/nn/optimizer/lamb.py index d5de267f73ee..0d742487f473 100644 --- a/colossalai/nn/optimizer/lamb.py +++ b/colossalai/nn/optimizer/lamb.py @@ -51,27 +51,27 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data if grad.is_sparse: - raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.') + raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instead.") state = self.state[p] # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p) + state["exp_avg"] = torch.zeros_like(p) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] - state['step'] += 1 + state["step"] += 1 # Decay the first and second moment running average coefficient # m_t @@ -84,22 +84,22 @@ def step(self, closure=None): # bias_correction2 = 1 - beta2 ** state['step'] # Apply bias to lr to avoid broadcast. # * math.sqrt(bias_correction2) / bias_correction1 - step_size = group['lr'] + step_size = group["lr"] weight_norm = p.data.pow(2).sum().sqrt() - adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) - if group['weight_decay'] != 0: - adam_step.add_(p.data, alpha=group['weight_decay']) + adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) + if group["weight_decay"] != 0: + adam_step.add_(p.data, alpha=group["weight_decay"]) adam_norm = adam_step.pow(2).sum().sqrt() if weight_norm == 0 or adam_norm == 0: trust_ratio = 1 else: trust_ratio = weight_norm / adam_norm - state['weight_norm'] = weight_norm - state['adam_norm'] = adam_norm - state['trust_ratio'] = trust_ratio + state["weight_norm"] = weight_norm + state["adam_norm"] = adam_norm + state["trust_ratio"] = trust_ratio if self.adam: trust_ratio = 1 diff --git a/colossalai/nn/optimizer/lars.py b/colossalai/nn/optimizer/lars.py index 58393fdae4bf..b117c00846d1 100644 --- a/colossalai/nn/optimizer/lars.py +++ b/colossalai/nn/optimizer/lars.py @@ -19,13 +19,9 @@ class Lars(Optimizer): weight_decay (float, optional): weight decay (L2 penalty) (default: 0) """ - def __init__(self, - params: Iterable[torch.nn.Parameter], - lr=1e-3, - momentum=0, - eeta=1e-3, - weight_decay=0, - epsilon=0.0) -> None: + def __init__( + self, params: Iterable[torch.nn.Parameter], lr=1e-3, momentum=0, eeta=1e-3, weight_decay=0, epsilon=0.0 + ) -> None: if not isinstance(lr, float) or lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -54,14 +50,14 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - weight_decay = group['weight_decay'] - momentum = group['momentum'] - eeta = group['eeta'] - lr = group['lr'] - lars = group['lars'] - eps = group['epsilon'] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + eeta = group["eeta"] + lr = group["lr"] + lars = group["lars"] + eps = group["epsilon"] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue decayed_grad = p.grad @@ -69,9 +65,11 @@ def step(self, closure=None): if lars: w_norm = torch.norm(p) g_norm = torch.norm(p.grad) - trust_ratio = torch.where(w_norm > 0 and g_norm > 0, - eeta * w_norm / (g_norm + weight_decay * w_norm + eps), - torch.ones_like(w_norm)) + trust_ratio = torch.where( + w_norm > 0 and g_norm > 0, + eeta * w_norm / (g_norm + weight_decay * w_norm + eps), + torch.ones_like(w_norm), + ) trust_ratio.clamp_(0.0, 50) scaled_lr *= trust_ratio.item() if weight_decay != 0: @@ -80,10 +78,10 @@ def step(self, closure=None): if momentum != 0: param_state = self.state[p] - if 'momentum_buffer' not in param_state: - buf = param_state['momentum_buffer'] = torch.clone(decayed_grad).detach() + if "momentum_buffer" not in param_state: + buf = param_state["momentum_buffer"] = torch.clone(decayed_grad).detach() else: - buf = param_state['momentum_buffer'] + buf = param_state["momentum_buffer"] buf.mul_(momentum).add_(decayed_grad) decayed_grad = buf diff --git a/colossalai/nn/optimizer/nvme_optimizer.py b/colossalai/nn/optimizer/nvme_optimizer.py index fb3a4d87be60..fd02bfb683e1 100644 --- a/colossalai/nn/optimizer/nvme_optimizer.py +++ b/colossalai/nn/optimizer/nvme_optimizer.py @@ -19,13 +19,11 @@ class NVMeOptimizer(torch.optim.Optimizer): Raises: ImportError: Raise if ``tensornvme`` is not installed. - """ + """ - def __init__(self, - params, - defaults: dict, - nvme_offload_fraction: float = 0.0, - offload_dir: Optional[str] = None) -> None: + def __init__( + self, params, defaults: dict, nvme_offload_fraction: float = 0.0, offload_dir: Optional[str] = None + ) -> None: assert 0.0 <= nvme_offload_fraction <= 1.0 super().__init__(params, defaults) self.nvme_offload_fraction = float(nvme_offload_fraction) @@ -34,9 +32,9 @@ def __init__(self, from tensornvme import DiskOffloader from tensornvme._C import get_backends except ModuleNotFoundError: - raise ModuleNotFoundError('Please install tensornvme to use NVMeOptimizer') + raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer") self.offload_dir = offload_dir or tempfile.mkdtemp() - backend = 'uring' if 'uring' in get_backends() else 'aio' + backend = "uring" if "uring" in get_backends() else "aio" self.offloader = DiskOffloader(self.offload_dir, 8, backend=backend) else: self.offload_dir = None @@ -53,13 +51,17 @@ def __init__(self, def _get_numel(self) -> int: numel = 0 for group in self.param_groups: - for p in group['params']: + for p in group["params"]: numel += p.storage().size() return numel def _post_state_init(self, param: Parameter) -> None: numel = param.storage().size() - if self.offloader is not None and param.device.type == 'cpu' and numel + self.offloaded_numel <= self.can_offload_numel: + if ( + self.offloader is not None + and param.device.type == "cpu" + and numel + self.offloaded_numel <= self.can_offload_numel + ): self.is_on_nvme[param] = True self.offloaded_numel += numel else: @@ -70,11 +72,11 @@ def _setup_prefetch_params(self) -> List[Parameter]: return assert len(self.prefetch_params) == 0 and len(self.param_to_prefetch_idx) == 0 for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if len(self.state[p]) > 0 and self.is_on_nvme[p]: - assert p.device.type == 'cpu' + assert p.device.type == "cpu" self.param_to_prefetch_idx[p] = len(self.prefetch_params) self.prefetch_params.append(p) @@ -156,7 +158,7 @@ def load_state_dict(self, state_dict: dict) -> None: super().load_state_dict(state_dict) def __del__(self) -> None: - if getattr(self, 'offloader', None) is not None: + if getattr(self, "offloader", None) is not None: del self.offloader if os.path.exists(self.offload_dir): try: diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py index e88a1f00a1b7..4754212c1914 100644 --- a/colossalai/pipeline/__init__.py +++ b/colossalai/pipeline/__init__.py @@ -3,9 +3,9 @@ from .stage_manager import PipelineStageManager __all__ = [ - 'PipelineSchedule', - 'OneForwardOneBackwardSchedule', - 'InterleavedSchedule', - 'PipelineP2PCommunication', - 'PipelineStageManager', + "PipelineSchedule", + "OneForwardOneBackwardSchedule", + "InterleavedSchedule", + "PipelineP2PCommunication", + "PipelineStageManager", ] diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index aed85cf91512..c69bbe6e8521 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -29,11 +29,11 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - Any: object after unpickled """ buf = tensor.numpy().tobytes()[:tensor_size] - if b'cuda' in buf: + if b"cuda" in buf: buf_array = bytearray(buf) device_index = torch.cuda.current_device() # There might be more than one output tensors during forward - for cuda_str in re.finditer(b'cuda', buf_array): + for cuda_str in re.finditer(b"cuda", buf_array): pos = cuda_str.start() buf_array[pos + 5] = 48 + device_index buf = bytes(buf_array) @@ -45,10 +45,9 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - return unpickle -def _broadcast_object_list(object_list: List[Any], - src: int, - group: ProcessGroup, - device: Optional[Union[torch.device, str, int]] = None): +def _broadcast_object_list( + object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None +): """This is a modified version of the broadcast_object_list in torch.distribution The only difference is that object will be move to correct device after unpickled. If local_rank = src, then object list will be sent to rank src. Otherwise, object list will @@ -99,8 +98,8 @@ def _broadcast_object_list(object_list: List[Any], if my_rank == src: object_tensor = torch.cat(tensor_list) else: - object_tensor = torch.empty( # type: ignore[call-overload] - torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] dtype=torch.uint8, ) @@ -114,7 +113,7 @@ def _broadcast_object_list(object_list: List[Any], if my_rank != src: for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset:offset + obj_size] + obj_view = object_tensor[offset : offset + obj_size] obj_view = obj_view.type(torch.uint8) if obj_view.device != torch.device("cpu"): obj_view = obj_view.cpu() @@ -123,8 +122,10 @@ def _broadcast_object_list(object_list: List[Any], unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) # unconsistence in device - if isinstance(unpickle_object, - torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != torch.cuda.current_device() + ): unpickle_object = unpickle_object.cuda() object_list[i] = unpickle_object @@ -160,7 +161,6 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: class PipelineP2PCommunication: - def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager @@ -192,8 +192,9 @@ def recv_backward(self, next_rank: int = None) -> Any: if next_rank is None: next_rank = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() - output_tensor_grad = _recv_object(next_rank, cur_rank, - self.stage_manager.get_p2p_process_group(next_rank, cur_rank)) + output_tensor_grad = _recv_object( + next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank) + ) return output_tensor_grad diff --git a/colossalai/pipeline/schedule/__init__.py b/colossalai/pipeline/schedule/__init__.py index 07c0f5927060..6845dc23753b 100644 --- a/colossalai/pipeline/schedule/__init__.py +++ b/colossalai/pipeline/schedule/__init__.py @@ -3,7 +3,7 @@ from .one_f_one_b import OneForwardOneBackwardSchedule __all__ = [ - 'PipelineSchedule', - 'OneForwardOneBackwardSchedule', - 'InterleavedSchedule', + "PipelineSchedule", + "OneForwardOneBackwardSchedule", + "InterleavedSchedule", ] diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 583558551b3c..271b3238f5c4 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -4,24 +4,15 @@ import torch import torch.cuda from torch.nn import Module -from torch.utils._pytree import ( - SUPPORTED_NODES, - LeafSpec, - TreeSpec, - _is_leaf, - _register_pytree_node, - tree_flatten, - tree_map, - tree_unflatten, -) +from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten # this register are for torch under version 1.13.1, maybe removed in the future -def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]: +def _odict_flatten(d: "OrderedDict[Any, Any]") -> Tuple[List[Any], Any]: return list(d.values()), list(d.keys()) -def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]': +def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]": return OrderedDict((key, value) for key, value in zip(context, values)) @@ -45,7 +36,7 @@ def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: # Recursively flatten the children result: List[Any] = [] - children_specs: List['TreeSpec'] = [] + children_specs: List["TreeSpec"] = [] for child in child_pytrees: flat, child_spec = tree_flatten_hf(child) result += flat @@ -87,7 +78,7 @@ def get_batch_size(batch: Any) -> int: for data in data_list: if isinstance(data, torch.Tensor): return data.size(0) - raise RuntimeError('No tensor found in the batch') + raise RuntimeError("No tensor found in the batch") def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any: @@ -104,7 +95,7 @@ def get_micro_batch(batch: Any, start: int, micro_batch_size: int) -> Any: def _get_tensor_slice(x: Any): if isinstance(x, torch.Tensor): - return x[start:start + micro_batch_size] + return x[start : start + micro_batch_size] return x return tree_map(_get_tensor_slice, batch) @@ -175,7 +166,7 @@ def merge_batch(data: List[Any], batch_size_dim=0) -> Any: for elem_batch in zip(*flattened_data): if isinstance(elem_batch[0], torch.Tensor): - if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs + if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs merged_data.append(None) else: merged_data.append(torch.cat(elem_batch, dim=batch_size_dim)) diff --git a/colossalai/pipeline/schedule/base.py b/colossalai/pipeline/schedule/base.py index b0fa6e6ad2b8..1bce297862c8 100644 --- a/colossalai/pipeline/schedule/base.py +++ b/colossalai/pipeline/schedule/base.py @@ -8,17 +8,18 @@ class PipelineSchedule: - def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager - def forward_backward_step(self, - model: Module, - data_iter: Iterable, - criterion: Callable[[Any, Any], Tensor], - optimizer: Optional[OptimizerWrapper] = None, - return_loss: bool = False, - return_outputs: bool = False) -> dict: + def forward_backward_step( + self, + model: Module, + data_iter: Iterable, + criterion: Callable[[Any, Any], Tensor], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: """Forward and backward step for pipeline training. Args: diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 6fdb09be5f32..780437155c61 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -16,11 +16,11 @@ class InterleavedSchedule(PipelineSchedule): - def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None: self.num_model_chunks = num_model_chunks - assert num_microbatches % self.num_model_chunks == 0, \ - "Number of microbatches should be an integer multiple of number of model chunks" + assert ( + num_microbatches % self.num_model_chunks == 0 + ), "Number of microbatches should be an integer multiple of number of model chunks" super().__init__(stage_manager) self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches @@ -42,8 +42,7 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" + assert self.batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches" self.microbatch_size = self.batch_size // self.num_microbatches def load_micro_batch(self, model_chunk_id: int) -> Any: @@ -72,7 +71,7 @@ def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int: microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks) model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages if not forward: - model_chunk_id = (self.num_model_chunks - model_chunk_id - 1) + model_chunk_id = self.num_model_chunks - model_chunk_id - 1 return model_chunk_id def is_first_stage(self, model_chunk_id: int) -> bool: @@ -161,13 +160,15 @@ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None if not self.is_first_stage(model_chunk_id): self.comm.send_backward(input_object, prev_rank) - def forward_step(self, - model_chunk: Module, - model_chunk_id: int, - input_obj: Optional[dict], - criterion: Callable, - accum_loss: Optional[torch.Tensor] = None, - outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]: + def forward_step( + self, + model_chunk: Module, + model_chunk_id: int, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ) -> Union[torch.Tensor, dict]: """Forward one step of the pipeline Args: model (Module): Model Chunk to be run @@ -195,8 +196,13 @@ def forward_step(self, else: return output_obj - def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], - output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]: + def backward_step( + self, + optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: """Backward one step of the pipeline Args: @@ -235,13 +241,15 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], input_obj_grad[k] = v.grad return input_obj_grad - def forward_backward_step(self, - model_chunk: Module, - data_iter: Iterable, - criterion: Callable[..., Any], - optimizer: Optional[OptimizerWrapper] = None, - return_loss: bool = False, - return_outputs: bool = False) -> dict: + def forward_backward_step( + self, + model_chunk: Module, + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: """Runs interleaved 1F1B schedule, with communication between pipeline stages. Args: @@ -321,7 +329,7 @@ def forward_backward_step(self, # Run 1F1B in steady state. for i in range(num_microbatches_remaining): model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True) - last_iteration = (i == (num_microbatches_remaining - 1)) + last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) if forward_only: @@ -369,4 +377,4 @@ def forward_backward_step(self, if outputs is not None: outputs = merge_batch(outputs) - return {'loss': accum_loss, 'outputs': outputs} + return {"loss": accum_loss, "outputs": outputs} diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index fbd0f9f0d4c0..4eaf135fd5db 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -25,11 +25,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): - - def __init__(self, - stage_manager: PipelineStageManager, - num_microbatches: Optional[int] = None, - microbatch_size: Optional[int] = None) -> None: + def __init__( + self, + stage_manager: PipelineStageManager, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + ) -> None: """1F1B pipeline schedule. Args: @@ -38,8 +39,9 @@ def __init__(self, microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None. """ super().__init__(stage_manager) - assert num_microbatches is not None or microbatch_size is not None, \ - "Either num_microbatches or microbatch_size should be provided" + assert ( + num_microbatches is not None or microbatch_size is not None + ), "Either num_microbatches or microbatch_size should be provided" self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches self.microbatch_size = microbatch_size @@ -62,12 +64,12 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch_size = get_batch_size(batch) self.microbatch_offset = 0 if not self._use_microbatch_size: - assert self.batch_size % self.num_microbatches == 0, \ - "Batch size should divided by the number of microbatches" + assert ( + self.batch_size % self.num_microbatches == 0 + ), "Batch size should divided by the number of microbatches" self.microbatch_size = self.batch_size // self.num_microbatches else: - assert self.batch_size % self.microbatch_size == 0, \ - "Batch size should divided by the microbatch size" + assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" self.num_microbatches = self.batch_size // self.microbatch_size def load_micro_batch(self) -> Any: @@ -136,12 +138,14 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: if not self.stage_manager.is_first_stage(): self.comm.send_backward(input_object, prev_rank) - def forward_step(self, - model: Module, - input_obj: Optional[dict], - criterion: Callable, - accum_loss: Optional[torch.Tensor] = None, - outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]: + def forward_step( + self, + model: Module, + input_obj: Optional[dict], + criterion: Callable, + accum_loss: Optional[torch.Tensor] = None, + outputs: Optional[List[Any]] = None, + ) -> Union[torch.Tensor, dict]: """Forward one step of the pipeline Args: @@ -159,7 +163,6 @@ def forward_step(self, # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict output_obj = model_forward(model, micro_batch, input_obj) if self.stage_manager.is_last_stage(): - loss = criterion(output_obj, micro_batch) / self.num_microbatches if accum_loss is not None: accum_loss.add_(loss.detach()) @@ -169,8 +172,13 @@ def forward_step(self, else: return output_obj - def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], - output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]: + def backward_step( + self, + optimizer: OptimizerWrapper, + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: """Backward one step of the pipeline Args: @@ -208,13 +216,15 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict], input_obj_grad[k] = v.grad return input_obj_grad - def forward_backward_step(self, - model: Module, - data_iter: Iterable, - criterion: Callable[..., Any], - optimizer: Optional[OptimizerWrapper] = None, - return_loss: bool = False, - return_outputs: bool = False) -> dict: + def forward_backward_step( + self, + model: Module, + data_iter: Iterable, + criterion: Callable[..., Any], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = False, + return_outputs: bool = False, + ) -> dict: """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. Args: @@ -273,7 +283,7 @@ def forward_backward_step(self, # Run 1F1B in steady state. for i in range(num_microbatches_remaining): - last_iteration = (i == (num_microbatches_remaining - 1)) + last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) if forward_only: @@ -316,5 +326,5 @@ def forward_backward_step(self, if outputs is not None: if isinstance(model, ModelWrapper): model = model.unwrap() - outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0)) - return {'loss': accum_loss, 'outputs': outputs} + outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0)) + return {"loss": accum_loss, "outputs": outputs} diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 6ba7dc629958..b79867a2c651 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from typing import Dict, List, Optional, Tuple import torch.distributed as dist @@ -28,13 +27,11 @@ def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bo # init prev and next coord coord = self.pg_mesh.coordinate() # the prev rank of rank0 is the last rank - prev_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] - self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap') + prev_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1 :] + self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode="wrap") # the next rank of the last rank is rank0 - next_coord = coord[: self.pipeline_axis] + \ - (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] - self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap') + next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :] + self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap") # init p2p process groups stages = list(range(self.num_stages)) diff --git a/colossalai/shardformer/_utils.py b/colossalai/shardformer/_utils.py index c553080de0a0..96d6cea21075 100644 --- a/colossalai/shardformer/_utils.py +++ b/colossalai/shardformer/_utils.py @@ -13,14 +13,14 @@ def get_obj_list_element(obj, attr: str): attr (str): The suffix of the attribute to get """ - re_pattern = r'\[\d+\]' + re_pattern = r"\[\d+\]" prog = re.compile(re_pattern) result = prog.search(attr) if result: matched_brackets = result.group() - matched_index = matched_brackets.replace('[', '') - matched_index = matched_index.replace(']', '') - attr_ = attr.replace(matched_brackets, '') + matched_index = matched_brackets.replace("[", "") + matched_index = matched_index.replace("]", "") + attr_ = attr.replace(matched_brackets, "") container_obj = getattr(obj, attr_) obj = container_obj[int(matched_index)] else: @@ -38,14 +38,14 @@ def set_obj_list_element(obj, attr: str, value): obj (object): The object to set attr (str): the string including a list index like `layers[0]` """ - re_pattern = r'\[\d+\]' + re_pattern = r"\[\d+\]" prog = re.compile(re_pattern) result = prog.search(attr) if result: matched_brackets = result.group() - matched_index = matched_brackets.replace('[', '') - matched_index = matched_index.replace(']', '') - attr_ = attr.replace(matched_brackets, '') + matched_index = matched_brackets.replace("[", "") + matched_index = matched_index.replace("]", "") + attr_ = attr.replace(matched_brackets, "") container_obj = getattr(obj, attr_) container_obj[int(matched_index)] = value else: @@ -60,7 +60,7 @@ def hasattr_(obj, attr: str): obj (object): The object to check attr (str): The multi level attr to check """ - attrs = attr.split('.') + attrs = attr.split(".") for a in attrs: try: obj = get_obj_list_element(obj, a) @@ -80,7 +80,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False): ignore (bool): Whether to ignore when the attr doesn't exist """ - attrs = attr.split('.') + attrs = attr.split(".") for a in attrs[:-1]: try: obj = get_obj_list_element(obj, a) @@ -101,7 +101,7 @@ def getattr_(obj, attr: str, ignore: bool = False): ignore (bool): Whether to ignore when the attr doesn't exist """ - attrs = attr.split('.') + attrs = attr.split(".") for a in attrs: try: obj = get_obj_list_element(obj, a) diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py index 81be2017855c..b03e6201dce8 100644 --- a/colossalai/shardformer/examples/convergence_benchmark.py +++ b/colossalai/shardformer/examples/convergence_benchmark.py @@ -7,7 +7,7 @@ import torch.distributed as dist from data import GLUEDataBuilder from torch import nn -from torch.optim import Adam, AdamW, Optimizer +from torch.optim import Adam, Optimizer from torch.utils._pytree import tree_map from torch.utils.data import DataLoader from tqdm import tqdm @@ -15,12 +15,10 @@ import colossalai from colossalai.cluster import DistCoordinator -from colossalai.nn.optimizer import HybridAdam from colossalai.shardformer import ShardConfig, ShardFormer def to_device(x: Any, device: torch.device) -> Any: - def _to(t: Any): if isinstance(t, torch.Tensor): return t.to(device) @@ -34,10 +32,12 @@ def train(args): coordinator = DistCoordinator() # prepare for data and dataset - data_builder = GLUEDataBuilder(model_name_or_path=args.pretrain, - task_name=args.task, - train_batch_size=args.batch_size, - eval_batch_size=args.batch_size) + data_builder = GLUEDataBuilder( + model_name_or_path=args.pretrain, + task_name=args.task, + train_batch_size=args.batch_size, + eval_batch_size=args.batch_size, + ) train_dataloader = data_builder.train_dataloader() test_dataloader = data_builder.test_dataloader() @@ -49,10 +49,10 @@ def train(args): # if multiple GPUs, shard the model if dist.get_world_size() > 1: - tp_group = dist.new_group(backend='nccl') - shard_config = ShardConfig(tensor_parallel_process_group=tp_group, - enable_tensor_parallelism=True, - enable_all_optimization=True) + tp_group = dist.new_group(backend="nccl") + shard_config = ShardConfig( + tensor_parallel_process_group=tp_group, enable_tensor_parallelism=True, enable_all_optimization=True + ) shard_former = ShardFormer(shard_config=shard_config) model, _ = shard_former.optimize(model) @@ -64,21 +64,40 @@ def train(args): num_warmup_steps=math.ceil(max_steps * args.warmup_fraction), num_training_steps=max_steps, ) - fit(model, optim, lr_scheduler, train_dataloader, args.max_epochs, args.accumulation_steps, args.batch_size, - coordinator) - results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, - coordinator) + fit( + model, + optim, + lr_scheduler, + train_dataloader, + args.max_epochs, + args.accumulation_steps, + args.batch_size, + coordinator, + ) + results = evaluate_model( + model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, coordinator + ) if coordinator.is_master(): print(results) - if args.target_f1 is not None and 'f1' in results: - assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' - - -def fit(model: nn.Module, optimizer: Optimizer, scheduler, train_dataloader, max_epochs, accumulation_steps, batch_size, - coordinator): - step_bar = tqdm(range(len(train_dataloader) // accumulation_steps * max_epochs), - desc=f'steps', - disable=not coordinator.is_master()) + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +def fit( + model: nn.Module, + optimizer: Optimizer, + scheduler, + train_dataloader, + max_epochs, + accumulation_steps, + batch_size, + coordinator, +): + step_bar = tqdm( + range(len(train_dataloader) // accumulation_steps * max_epochs), + desc=f"steps", + disable=not coordinator.is_master(), + ) total_loss = 0 for epoch in range(max_epochs): model.train() @@ -93,19 +112,23 @@ def fit(model: nn.Module, optimizer: Optimizer, scheduler, train_dataloader, max optimizer.step() scheduler.step() optimizer.zero_grad() - step_bar.set_postfix({ - 'epoch': epoch, - 'loss': total_loss / batch_size, - 'lr': scheduler.get_last_lr()[0] - }) + step_bar.set_postfix( + {"epoch": epoch, "loss": total_loss / batch_size, "lr": scheduler.get_last_lr()[0]} + ) total_loss = 0 step_bar.update() # evaluate @torch.no_grad() -def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, - task_name: str, eval_splits: List[str], coordinator: DistCoordinator): +def evaluate_model( + model: nn.Module, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + coordinator: DistCoordinator, +): metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) model.eval() @@ -127,7 +150,7 @@ def evaluate_subset(dataloader: DataLoader): results = metric.compute() if coordinator.is_master(): - results['loss'] = accum_loss.item() / (len(dataloader) * dataloader.batch_size) + results["loss"] = accum_loss.item() / (len(dataloader) * dataloader.batch_size) return results if isinstance(test_dataloader, DataLoader): @@ -137,21 +160,21 @@ def evaluate_subset(dataloader: DataLoader): final_results = {} for split, sub_loader in zip(eval_splits, test_dataloader): results = evaluate_subset(sub_loader) - final_results.update({f'{k}_{split}': v for k, v in results.items()}) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) return final_results -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('--model', type=str, default="bert") - parser.add_argument('--pretrain', type=str, default="bert-base-uncased") - parser.add_argument('--max_epochs', type=int, default=1) - parser.add_argument('--batch_size', type=int, default=4) - parser.add_argument('--lr', type=float, default=2.4e-5) - parser.add_argument('--fused_layernorm', type=bool, default=False) - parser.add_argument('--accumulation_steps', type=int, default=8) - parser.add_argument('--warmup_fraction', type=float, default=0.03) - parser.add_argument('--target_f1', type=float, default=None) + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument("--model", type=str, default="bert") + parser.add_argument("--pretrain", type=str, default="bert-base-uncased") + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--lr", type=float, default=2.4e-5) + parser.add_argument("--fused_layernorm", type=bool, default=False) + parser.add_argument("--accumulation_steps", type=int, default=8) + parser.add_argument("--warmup_fraction", type=float, default=0.03) + parser.add_argument("--target_f1", type=float, default=None) args = parser.parse_args() train(args) diff --git a/colossalai/shardformer/examples/data.py b/colossalai/shardformer/examples/data.py index 6296d4be4eb0..ddf44a874659 100644 --- a/colossalai/shardformer/examples/data.py +++ b/colossalai/shardformer/examples/data.py @@ -6,7 +6,6 @@ class GLUEDataBuilder: - task_text_field_map = { "cola": ["sentence"], "sst2": ["sentence"], @@ -86,14 +85,12 @@ def prepare_data(self): def train_dataloader(self): if self.plugin == None: - return self.native_prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) - return self.plugin.prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) + return self.native_prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) def val_dataloader(self): if self.plugin == None: @@ -118,7 +115,6 @@ def test_dataloader(self): ] def convert_to_features(self, example_batch): - # Either encode single sentence or sentence pairs if len(self.text_fields) > 1: texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) @@ -126,10 +122,9 @@ def convert_to_features(self, example_batch): texts_or_text_pairs = example_batch[self.text_fields[0]] # Tokenize the text/text pairs - features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, - max_length=self.max_seq_length, - padding='max_length', - truncation=True) + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) # Rename label to labels to make it easier to pass to model forward features["labels"] = example_batch["label"] @@ -137,10 +132,6 @@ def convert_to_features(self, example_batch): return features def native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False): - - return DataLoader(dataset, - batch_size=batch_size, - sampler=None, - shuffle=shuffle, - drop_last=drop_last, - pin_memory=pin_memory) + return DataLoader( + dataset, batch_size=batch_size, sampler=None, shuffle=shuffle, drop_last=drop_last, pin_memory=pin_memory + ) diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py index 2f186709d946..81215dcdf5d4 100644 --- a/colossalai/shardformer/examples/performance_benchmark.py +++ b/colossalai/shardformer/examples/performance_benchmark.py @@ -20,35 +20,35 @@ def data_gen_for_sequence_classification(batch_size, seq_length): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen(batch_size, seq_length) - data['labels'] = torch.ones((batch_size), dtype=torch.long) + data["labels"] = torch.ones((batch_size), dtype=torch.long) return data -MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4, - hidden_size=128, - intermediate_size=256, - num_attention_heads=4, - max_position_embeddings=128, - num_labels=16, - pad_token_id=2) +MODEL_CONFIG = transformers.LlamaConfig( + num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128, + num_labels=16, + pad_token_id=2, +) BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64 model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG) # vary seq length for fixed head and batch=4 configs = [ - triton.testing.Benchmark(x_names=['N_CTX'], - x_vals=[2**i for i in range(8, 13)], - line_arg='provider', - line_vals=['org_model', 'shard_model'], - line_names=['org_model', 'shard_model'], - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'lama_for_sequence_classification-batch-{BATCH}', - args={ - 'BATCH': BATCH, - 'dtype': torch.float16, - 'model_func': model_func - }) + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(8, 13)], + line_arg="provider", + line_vals=["org_model", "shard_model"], + line_names=["org_model", "shard_model"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"lama_for_sequence_classification-batch-{BATCH}", + args={"BATCH": BATCH, "dtype": torch.float16, "model_func": model_func}, + ) ] @@ -85,4 +85,4 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d # torchrun --standalone --nproc_per_node=2 performance_benchmark.py if __name__ == "__main__": colossalai.launch_from_torch({}) - bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0) + bench_shardformer.run(save_path=".", print_data=dist.get_rank() == 0) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index c4586d18b90c..a134a2cbd21c 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -7,7 +7,17 @@ from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ - "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', - 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm', 'FusedRMSNorm', 'FusedLinear1D_Col', 'ParallelModule' + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + "GPT2FusedLinearConv1D_Col", + "GPT2FusedLinearConv1D_Row", + "DropoutForParallelInput", + "DropoutForReplicatedInput", + "cross_entropy_1d", + "FusedLayerNorm", + "FusedRMSNorm", + "FusedLinear1D_Col", + "ParallelModule", ] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 45b305733813..5ec48096183b 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,5 +1,3 @@ -from typing import Any - import torch import torch.distributed as dist import torch.nn.functional as F @@ -22,7 +20,7 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function): If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps: a value added to the denominator for numerical stability - """ + """ @staticmethod def forward(ctx, input, weight, bias, normalized_shape, eps): @@ -31,8 +29,9 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() - output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, - bias_, ctx.eps) + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output @@ -40,11 +39,9 @@ def forward(ctx, input, weight, bias, normalized_shape, eps): def backward(ctx, grad_output): input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias \ - = fused_mix_prec_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) + grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + ) return grad_input, grad_weight, grad_bias, None, None @@ -195,8 +192,9 @@ def backward(ctx, grad_output): input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) ] - output = torch.empty(input_.shape, dtype=input_parallel.dtype, - device=input_parallel.device).contiguous() + output = torch.empty( + input_.shape, dtype=input_parallel.dtype, device=input_parallel.device + ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # reduce-scatter scheduled first and have GPU resources allocated @@ -260,8 +258,9 @@ def forward(ctx, input_, process_group, dim): # do reduce-scatter new_shape = list(input_.shape) - assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ - f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group) input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)] output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) @@ -329,8 +328,9 @@ def backward(ctx, grad_output): input_list = [ item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim) ] - output = torch.empty(input_.shape, dtype=input_parallel.dtype, - device=input_parallel.device).contiguous() + output = torch.empty( + input_.shape, dtype=input_parallel.dtype, device=input_parallel.device + ).contiguous() handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # reduce-scatter scheduled first and have GPU resources allocated @@ -473,9 +473,10 @@ def _split(input_, dim=-1, process_group=None): # Split along last dimension. dim_size = input_.size(dim) - assert dim_size % world_size == 0, \ - f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ - f'cannot split tensor evenly' + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) tensor_list = torch.split(input_, dim_size // world_size, dim=dim) rank = dist.get_rank(process_group) @@ -502,7 +503,7 @@ def _gather(input_, dim=-1, process_group=None): def _reduce_scatter(input_, dim=1, process_group=None): - """ Do reduce-scatter operation. + """Do reduce-scatter operation. Args: input_ (`torch.Tensor`): The input tensor from sequence parallel region. @@ -515,8 +516,9 @@ def _reduce_scatter(input_, dim=1, process_group=None): # reduce-scatter new_shape = list(input_.shape) - assert new_shape[dim] % dist.get_world_size(process_group) == 0, \ - f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). ' + assert ( + new_shape[dim] % dist.get_world_size(process_group) == 0 + ), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). " new_shape[dim] = new_shape[dim] // world_size output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device) dist.reduce_scatter(output, input_, group=process_group) @@ -532,20 +534,24 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) -def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, - overlap): - return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, - async_grad_reduce_scatter, dim, overlap) +def linear_gather_forward_reducescatter_backward( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap +): + return _LinearWithGatherForwardReduceScatterBackward.apply( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + ) def linear_reducescatter_forward_gather_backward(input_, process_group, dim): return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim) -def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim, - overlap): - return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group, - async_grad_reduce_scatter, dim, overlap) +def matmul_gather_forward_reducescatter_backward( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap +): + return _MatmulWithGatherForwardReduceScatterBackward.apply( + input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap + ) def gather_forward_split_backward(input_, dim, process_group): diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index 2625fe97889a..8771913ee62f 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -7,7 +7,7 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['DropoutForParallelInput', 'DropoutForReplicatedInput'] +__all__ = ["DropoutForParallelInput", "DropoutForReplicatedInput"] class DropoutForParallelInput(ParallelModule, nn.Dropout): @@ -31,8 +31,9 @@ def __init__(self, p: float = 0.5, inplace: bool = False, process_group: Process self.randomizer = create_randomizer_with_offset(seed, process_group=process_group) @staticmethod - def from_native_module(module: nn.Dropout, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForParallelInput": + def from_native_module( + module: nn.Dropout, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "DropoutForParallelInput": """ Create a DropoutForParallelInput layer from a native dropout layer. """ @@ -68,8 +69,8 @@ def __init__(self, p: float = 0.5, inplace: bool = False, process_group: Process @staticmethod def from_native_module( - module: nn.Dropout, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForReplicatedInput": + module: nn.Dropout, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "DropoutForReplicatedInput": """ Create a Dropout1D layer from a native dropout layer. """ diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 847ca175ad57..62163cb009aa 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -24,7 +24,7 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['Embedding1D', 'VocabParallelEmbedding1D'] +__all__ = ["Embedding1D", "VocabParallelEmbedding1D"] class Embedding1D(ParallelModule): @@ -57,18 +57,20 @@ class Embedding1D(ParallelModule): `init `_ """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = True, - weight: Optional[nn.Parameter] = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = True, + weight: Optional[nn.Parameter] = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings @@ -86,7 +88,7 @@ def __init__(self, # Parameters. if weight is None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -100,10 +102,9 @@ def __init__(self, self.reset_parameters(weight_initializer) @staticmethod - def from_native_module(module: nn.Embedding, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None, - *args, - **kwargs) -> "Embedding1D": + def from_native_module( + module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]] = None, *args, **kwargs + ) -> "Embedding1D": r""" Build a 1D parallelized Embedding from a native nn.Embedding module. """ @@ -123,19 +124,21 @@ def from_native_module(module: nn.Embedding, if sparse: raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") - embedding = Embedding1D(num_embeddings=num_embedding, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - process_group=process_group, - dtype=dtype, - device=device, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - weight=module.weight, - *args, - **kwargs) + embedding = Embedding1D( + num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + weight=module.weight, + *args, + **kwargs, + ) return embedding @@ -188,17 +191,19 @@ class VocabParallelEmbedding1D(ParallelModule): `init `_. """ - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - weight: Optional[nn.Parameter] = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight: Optional[nn.Parameter] = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs, + ): super().__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim @@ -223,7 +228,7 @@ def __init__(self, # parameter if weight is None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -236,8 +241,9 @@ def __init__(self, self.reset_parameters(weight_initializer) @staticmethod - def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: r""" Convert a native pytorch embedding module to a parallel module. """ @@ -250,19 +256,20 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, # ensure only one process group is used if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] # create the parallel module - vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - device=device, - process_group=process_group, - weight=module.weight, - *args, - **kwargs) + vocab_embedding_1d = VocabParallelEmbedding1D( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + process_group=process_group, + weight=module.weight, + *args, + **kwargs, + ) return vocab_embedding_1d @@ -273,8 +280,11 @@ def reset_parameters(self, weight_initializer) -> None: self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + if ( + self.padding_idx is not None + and self.padding_idx >= self.vocab_start_index + and self.padding_idx < self.vocab_end_index + ): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) @@ -294,11 +304,12 @@ def forward(self, input_: Tensor) -> Tensor: masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) + output_parallel = F.embedding( + masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs + ) # Mask the output embedding. - output_parallel[input_mask, :] = 0. + output_parallel[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. output = reduce_forward(output_parallel, self.process_group) return output diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 111d51b3f8d8..cf2003877d3c 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -33,7 +33,7 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['Linear1D_Col', 'Linear1D_Row'] +__all__ = ["Linear1D_Col", "Linear1D_Row"] class Linear1D_Col(ParallelModule): @@ -65,22 +65,24 @@ class Linear1D_Col(ParallelModule): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = False, - seq_parallel: bool = False, - seq_parallel_dim: int = 1, - overlap: torch.cuda.Stream = None, - skip_bias_add: bool = False, - weight: Optional[Parameter] = None, - bias_: Optional[Parameter] = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + seq_parallel: bool = False, + seq_parallel_dim: int = 1, + overlap: torch.cuda.Stream = None, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() # Keep input parameters @@ -95,7 +97,7 @@ def __init__(self, self.process_group = process_group if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -103,13 +105,13 @@ def __init__(self, # sanity check if weight is not None: - assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" else: - assert bias_ is None, 'bias_ must be None if weight is None' + assert bias_ is None, "bias_ must be None if weight is None" # Parameters. if weight is None: - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -135,8 +137,9 @@ def __init__(self, self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ @@ -149,8 +152,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -159,17 +161,20 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis if out_features % tp_size != 0: raise ValueError( - f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!") - - linear_1d = Linear1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - weight=module.weight, - bias_=module.bias, - *args, - **kwargs) + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = Linear1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) return linear_1d @@ -181,9 +186,11 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) # Set up backprop all-reduce. input_parallel = input_ @@ -191,9 +198,9 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: # Matrix multiply. bias = self.bias if not self.skip_bias_add else None if self.seq_parallel: - output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, - self.seq_parallel_dim, self.overlap) + output_parallel = linear_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap + ) else: output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) @@ -210,7 +217,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: class Linear1D_Row(ParallelModule): - r""" Linear layer with row parallelism + r"""Linear layer with row parallelism Args: in_features (int): size of each input sample. @@ -231,22 +238,24 @@ class Linear1D_Row(ParallelModule): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - seq_parallel: bool = False, - seq_parallel_dim: int = 1, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight: Optional[Parameter] = None, - bias_: Optional[Parameter] = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + seq_parallel: bool = False, + seq_parallel_dim: int = 1, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1, + ): super().__init__() self.stream_chunk_num = stream_chunk_num @@ -262,7 +271,7 @@ def __init__(self, self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -270,14 +279,14 @@ def __init__(self, # sanity check if weight is not None: - assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" else: - assert bias_ is None, 'bias_ must be None if weight is None' + assert bias_ is None, "bias_ must be None if weight is None" # Parameters. if weight is None: # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -304,8 +313,9 @@ def __init__(self, self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ @@ -318,8 +328,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -328,17 +337,20 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") - - linear_1d = Linear1D_Row(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - weight=module.weight, - bias_=module.bias, - *args, - **kwargs) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = Linear1D_Row( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) return linear_1d @@ -366,14 +378,18 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) input_ = input_ else: - assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + assert ( + divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions + ) input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) if self.stream_chunk_num > 1: @@ -384,9 +400,9 @@ def forward(self, input_: Tensor) -> Tensor: handle_list = [] for i in range(self.stream_chunk_num): output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=self.process_group, - async_op=True) + handle = torch.distributed.all_reduce( + output_parallel_list[i], group=self.process_group, async_op=True + ) handle_list.append(handle) # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) for handle in handle_list: @@ -395,8 +411,9 @@ def forward(self, input_: Tensor) -> Tensor: else: output_parallel = F.linear(input_, self.weight) if self.seq_parallel: - output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, - self.seq_parallel_dim) + output = linear_reducescatter_forward_gather_backward( + output_parallel, self.process_group, self.seq_parallel_dim + ) else: output = reduce_forward(output_parallel, self.process_group) diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 7e3f6926b6d4..848e4a3a1f7d 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -3,7 +3,7 @@ from torch.autograd import Function from torch.distributed import ProcessGroup -__all__ = ['DistCrossEntropy', 'cross_entropy_1d'] +__all__ = ["DistCrossEntropy", "cross_entropy_1d"] class DistCrossEntropy(Function): @@ -61,8 +61,9 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: masked_target_1d = masked_target.view(-1) # extract the x[class] and set the x[other device] to zero - pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), - masked_target_1d] + pred_logits_1d = logits_2d[ + torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), masked_target_1d + ] pred_logits_1d = pred_logits_1d.clone().contiguous() pred_logits = pred_logits_1d.view_as(target) pred_logits[mask] = 0.0 @@ -102,8 +103,7 @@ def backward(ctx, grad_output): return grad_logits, None, None -def cross_entropy_1d(vocab_logits: torch.Tensor, - labels: torch.Tensor, - ignore_index: int = -100, - process_group: ProcessGroup = None) -> torch.Tensor: +def cross_entropy_1d( + vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None +) -> torch.Tensor: return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 0aea295664a7..19b973be8679 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -1,28 +1,49 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import torch import torch.nn as nn from colossalai.lazy import LazyInitContext -__all__ = ['FusedLayerNorm', 'FusedRMSNorm'] +__all__ = ["FusedLayerNorm", "FusedRMSNorm"] FAST_LAYERNORM_SUPPORTED_SIZE = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576, - 25600, 30720, 32768, 40960, 49152, 65536 + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, ] -class FusedLayerNorm(): +class FusedLayerNorm: r""" This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. """ def __init__(self) -> None: raise NotImplementedError( - 'FusedLayerNorm is not implemented as a physical class. ' - 'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.' + "FusedLayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex." ) @staticmethod @@ -32,10 +53,11 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: """ # check if apex is installed try: - import apex + pass except ImportError: raise ImportError( - 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel') + "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel" + ) LazyInitContext.materialize(module) # get the attributes of the module @@ -57,23 +79,24 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: else: from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm - layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps, - elementwise_affine=elementwise_affine).to(dtype).to(device) + layernorm = ( + ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) + ) layernorm.weight = module.weight layernorm.bias = module.bias return layernorm -class FusedRMSNorm(): +class FusedRMSNorm: """ This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. """ def __init__(self) -> None: raise NotImplementedError( - 'FusedRMSNorm is not implemented as a physical class. ' - 'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.' + "FusedRMSNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex." ) @staticmethod @@ -82,7 +105,7 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm except ImportError: raise ImportError( - 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel' + "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" ) LazyInitContext.materialize(module) diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index 4f391920e29b..6c0d83cc7a20 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -19,18 +19,16 @@ is_customized_distributed_tensor, is_distributed_tensor, sharded_tensor_to_param, - to_global, - to_global_for_customized_distributed_tensor, ) -__all__ = ['ParallelModule'] +__all__ = ["ParallelModule"] class ParallelModule(nn.Module, ABC): - @abstractmethod - def from_native_module(module: nn.Module, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None + ) -> "ParallelModule": """ Convert a native PyTorch module to a parallelized module. @@ -40,7 +38,6 @@ def from_native_module(module: nn.Module, If this is a list, the process group at the ith index of the list will correspond to the process group in the ith axis of the device mesh. Defaults to None, which means the global process group. """ - pass def _save_to_state_dict(self, destination, prefix, keep_vars): r"""Saves module state to `destination` dictionary, containing a state @@ -66,8 +63,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: destination[extra_state_key] = self.get_extra_state() - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): r"""Copies parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this @@ -112,9 +110,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss if key in state_dict: input_param = state_dict[key] if not torch.overrides.is_tensor_like(input_param): - error_msgs.append('While copying the parameter named "{}", ' - 'expected torch.Tensor or Tensor-like object from checkpoint but ' - 'received {}'.format(key, type(input_param))) + error_msgs.append( + 'While copying the parameter named "{}", ' + "expected torch.Tensor or Tensor-like object from checkpoint but " + "received {}".format(key, type(input_param)) + ) continue if is_distributed_tensor(param): @@ -136,19 +136,22 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss if not is_param_lazy and input_param.shape != param.shape: # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format(key, input_param.shape, param.shape)) + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(key, input_param.shape, param.shape) + ) continue try: with torch.no_grad(): param.copy_(input_param) except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(key, param.size(), input_param.size(), - ex.args)) + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args) + ) elif strict: missing_keys.append(key) @@ -164,7 +167,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss if strict: for key in state_dict.keys(): if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):] - input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + input_name = key[len(prefix) :] + input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child if input_name not in self._modules and input_name not in local_state: unexpected_keys.append(key) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py index 5ce77805f9b8..12476d050600 100644 --- a/colossalai/shardformer/layer/qkv_fused_linear.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -36,17 +36,16 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row', 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row'] +__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"] # ==================================== # For GPT Only # ==================================== -def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor, - n_fused: int, - process_group: ProcessGroup, - is_transposed: bool = False): +def split_fused_qkv_in_gpt2_style( + qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False +): """ The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2]. @@ -85,10 +84,9 @@ def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor, return weight_of_current_rank -def gather_fused_qkv_in_gpt2_style(qkv: torch.Tensor, - n_fused: int, - process_group: ProcessGroup, - is_transposed: bool = False): +def gather_fused_qkv_in_gpt2_style( + qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False +): """ The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2]. @@ -167,23 +165,25 @@ class GPT2FusedLinearConv1D_Col(ParallelModule): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - async_communication: bool = False, - gather_output: bool = False, - seq_parallel: bool = False, - overlap: bool = False, - skip_bias_add: bool = False, - n_fused: int = 3, - weight: Optional[Parameter] = None, - bias_: Optional[Parameter] = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + async_communication: bool = False, + gather_output: bool = False, + seq_parallel: bool = False, + overlap: bool = False, + skip_bias_add: bool = False, + n_fused: int = 3, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() # Keep input parameters @@ -199,7 +199,7 @@ def __init__(self, self.async_communication = async_communication if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -207,14 +207,14 @@ def __init__(self, # sanity check if weight is not None: - assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" else: - assert bias_ is None, 'bias_ must be None if weight is None' + assert bias_ is None, "bias_ must be None if weight is None" # Parameters. if weight is None: # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -249,8 +249,9 @@ def gather_fn(tensor): self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: r""" Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. @@ -268,8 +269,7 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -278,17 +278,20 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis if out_features % tp_size != 0: raise ValueError( - f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!") - - linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - weight=module.weight, - bias_=module.bias, - *args, - **kwargs) + f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = GPT2FusedLinearConv1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) return linear_1d @@ -300,22 +303,26 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[0], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[0] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) # Matrix multiply. bias = self.bias if not self.skip_bias_add else None if self.seq_parallel: input_parallel = input_ - output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias, - self.process_group, True, 1, self.overlap) + output_parallel = matmul_gather_forward_reducescatter_backward( + input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap + ) else: # Set up backprop all-reduce. input_parallel = reduce_backward(input_, self.process_group) - output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, - self.async_communication) + output_parallel = matmul_with_async_comm( + input_parallel, self.weight, bias, self.process_group, self.async_communication + ) if self.gather_output: # All-gather across the partitions. @@ -330,7 +337,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: class GPT2FusedLinearConv1D_Row(ParallelModule): - r""" Linear layer with row parallelism. + r"""Linear layer with row parallelism. This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. Args: @@ -351,21 +358,23 @@ class GPT2FusedLinearConv1D_Row(ParallelModule): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - seq_parallel: bool = False, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight: Optional[Parameter] = None, - bias_: Optional[Parameter] = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + seq_parallel: bool = False, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1, + ): super().__init__() self.stream_chunk_num = stream_chunk_num @@ -380,7 +389,7 @@ def __init__(self, self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -391,14 +400,14 @@ def __init__(self, # sanity check if weight is not None: - assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" else: - assert bias_ is None, 'bias_ must be None if weight is None' + assert bias_ is None, "bias_ must be None if weight is None" # Parameters. if weight is None: # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -424,8 +433,9 @@ def __init__(self, self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ @@ -438,8 +448,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -448,17 +457,20 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") - - linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - weight=module.weight, - bias_=module.bias, - *args, - **kwargs) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + + linear_1d = GPT2FusedLinearConv1D_Row( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) return linear_1d @@ -485,14 +497,18 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[0], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[0]) + assert ( + input_.shape[-1] == self.weight.shape[0] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[0] + ) input_ = input_ else: - assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions) + assert ( + divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0] + ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions + ) input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) if self.stream_chunk_num > 1: @@ -503,9 +519,9 @@ def forward(self, input_: Tensor) -> Tensor: handle_list = [] for i in range(self.stream_chunk_num): output_parallel_list[i] = torch.matmul(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=self.process_group, - async_op=True) + handle = torch.distributed.all_reduce( + output_parallel_list[i], group=self.process_group, async_op=True + ) handle_list.append(handle) # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) for handle in handle_list: @@ -559,21 +575,23 @@ class FusedLinear1D_Col(ParallelModule): `init `_. """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - async_communication: bool = False, - gather_output: bool = False, - skip_bias_add: bool = False, - n_fused: int = 3, - weight: Optional[Parameter] = None, - bias_: Optional[Parameter] = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + async_communication: bool = False, + gather_output: bool = False, + skip_bias_add: bool = False, + n_fused: int = 3, + weight: Optional[Parameter] = None, + bias_: Optional[Parameter] = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + ): super().__init__() # Keep input parameters self.in_features = in_features @@ -586,7 +604,7 @@ def __init__(self, self.async_communication = async_communication if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') + raise ValueError("cannot skip bias addition if bias is None") # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -594,14 +612,14 @@ def __init__(self, # sanity check if weight is not None: - assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None' + assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None" else: - assert bias_ is None, 'bias_ must be None if weight is None' + assert bias_ is None, "bias_ must be None if weight is None" # Parameters. if weight is None: # Initialize weight. - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs)) else: weight.data = weight.data.to(device=device, dtype=dtype) @@ -636,8 +654,9 @@ def gather_fn(tensor): self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, - *args, **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs + ) -> ParallelModule: r""" Convert a fused `torch.nn.linear` layer to a parallelized linear layer. @@ -654,19 +673,20 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] - linear_1d = FusedLinear1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - weight=module.weight, - bias_=module.bias, - *args, - **kwargs) + linear_1d = FusedLinear1D_Col( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + weight=module.weight, + bias_=module.bias, + *args, + **kwargs, + ) # # TODO: copy the sharded weights # with torch.no_grad(): @@ -693,9 +713,11 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) + assert ( + input_.shape[-1] == self.weight.shape[-1] + ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format( + input_.shape, self.weight.shape, self.weight.shape[-1] + ) # Set up backprop all-reduce. # input_parallel = reduce_backward(input_, self.process_group) input_parallel = input_ diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index 577bef076a7e..c3d8501cdeae 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -3,7 +3,6 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from torch.distributed.distributed_c10d import _get_global_rank class Randomizer: @@ -172,10 +171,9 @@ def synchronize_index(process_group: ProcessGroup = None): Randomizer._INDEX = index_tensor.item() -def create_randomizer_with_offset(seed: int, - process_group: ProcessGroup = None, - offset_by_rank: bool = True, - offset_by_index: bool = True): +def create_randomizer_with_offset( + seed: int, process_group: ProcessGroup = None, offset_by_rank: bool = True, offset_by_index: bool = True +): """ Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer. @@ -197,9 +195,11 @@ def create_randomizer_with_offset(seed: int, if offset_by_index: # check if the randomizer index is synchronized is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group) - assert is_synchronized, ("We detect that the randomizer index is not synchronized across processes." - "This is not allowed when we want to create a randomizer with offset by index." - "Please call Randomizer.synchronize_index() first.") + assert is_synchronized, ( + "We detect that the randomizer index is not synchronized across processes." + "This is not allowed when we want to create a randomizer with offset by index." + "Please call Randomizer.synchronize_index() first." + ) base_seed += Randomizer.index() Randomizer.increment_index() diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 30855a622adb..7411e1d0ec46 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -34,10 +34,10 @@ class BertPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of Bert models under pipeline setting. - ''' + """ @staticmethod def bert_model_forward( @@ -56,36 +56,37 @@ def bert_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage + hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, ): # TODO(jianghai): add explaination of the output here. r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.config.is_decoder: @@ -118,13 +119,13 @@ def bert_model_forward( # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False # past_key_values_length @@ -173,7 +174,8 @@ def bert_model_forward( if self.encoder.gradient_checkpointing and self.encoder.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False next_decoder_cache = () if use_cache else None @@ -184,12 +186,13 @@ def bert_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config is not None and shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if encoder_hidden_states is not None: encoder_hidden_states = split_forward_gather_backward( - encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx): if stage_manager.is_first_stage() and idx == 0: @@ -204,7 +207,6 @@ def bert_model_forward( if self.encoder.gradient_checkpointing and self.encoder.training: def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) @@ -234,14 +236,13 @@ def custom_forward(*inputs): if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + \ - (layer_outputs[2],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config is not None and shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -268,7 +269,7 @@ def custom_forward(*inputs): else: # intermediate stage always return dict return { - 'hidden_states': hidden_states, + "hidden_states": hidden_states, } @staticmethod @@ -295,10 +296,10 @@ def bert_for_pretraining_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai) left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False outputs = BertPipelineForwards.bert_model_forward( @@ -317,10 +318,6 @@ def bert_for_pretraining_forward( stage_index=stage_index, shard_config=shard_config, ) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): sequence_output, pooled_output = outputs[:2] @@ -345,11 +342,11 @@ def bert_for_pretraining_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') + hidden_states = outputs.get("hidden_states") # intermediate stage always return dict return { - 'hidden_states': hidden_states, + "hidden_states": hidden_states, } @staticmethod @@ -375,39 +372,39 @@ def bert_lm_head_model_forward( shard_config: ShardConfig = None, ): r""" - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are - ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` - past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - """ + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: use_cache = False if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False outputs = BertPipelineForwards.bert_model_forward( @@ -428,11 +425,9 @@ def bert_lm_head_model_forward( stage_manager=stage_manager, hidden_states=hidden_states if hidden_states is not None else None, stage_index=stage_index, - shard_config=shard_config) + shard_config=shard_config, + ) past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -459,9 +454,9 @@ def bert_lm_head_model_forward( cross_attentions=outputs.cross_attentions, ) else: - hidden_states = outputs.get('hidden_states') + hidden_states = outputs.get("hidden_states") # intermediate stage always return dict - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod def bert_for_masked_lm_forward( @@ -484,20 +479,20 @@ def bert_for_masked_lm_forward( shard_config: ShardConfig = None, ): r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., - config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the - loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` - """ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False outputs = BertPipelineForwards.bert_model_forward( @@ -525,7 +520,7 @@ def bert_for_masked_lm_forward( masked_lm_loss = None if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -539,8 +534,8 @@ def bert_for_masked_lm_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bert_for_next_sentence_prediction_forward( @@ -563,33 +558,33 @@ def bert_for_next_sentence_prediction_forward( ): # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair - (see `input_ids` docstring). Indices should be in `[0, 1]`: + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: - - 0 indicates sequence B is a continuation of sequence A, - - 1 indicates sequence B is a random sequence. + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. - Returns: + Returns: - Example: + Example: - ```python - >>> from transformers import AutoTokenizer, BertForNextSentencePrediction - >>> import torch + ```python + >>> from transformers import AutoTokenizer, BertForNextSentencePrediction + >>> import torch - >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") + >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased") - >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." - >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") - >>> outputs = model(**encoding, labels=torch.LongTensor([1])) - >>> logits = outputs.logits - >>> assert logits[0, 0] < logits[0, 1] # next sentence was random - ``` - """ + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ logger = logging.get_logger(__name__) if "next_sentence_label" in kwargs: @@ -603,26 +598,28 @@ def bert_for_next_sentence_prediction_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=shard_config) + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -644,9 +641,9 @@ def bert_for_next_sentence_prediction_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') + hidden_states = outputs.get("hidden_states") # intermediate stage always return dict - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod def bert_for_sequence_classification_forward( @@ -677,26 +674,28 @@ def bert_for_sequence_classification_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=shard_config) + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) if stage_manager.is_last_stage(): pooled_output = outputs[1] @@ -737,8 +736,8 @@ def bert_for_sequence_classification_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bert_for_token_classification_forward( @@ -767,26 +766,28 @@ def bert_for_token_classification_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=shard_config) + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -810,8 +811,8 @@ def bert_for_token_classification_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bert_for_multiple_choice_forward( @@ -842,10 +843,10 @@ def bert_for_multiple_choice_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False # in our pipeline design,input ids are copied for every stage and shouldn't be none @@ -857,8 +858,11 @@ def bert_for_multiple_choice_forward( attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None else None) + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) outputs = BertPipelineForwards.bert_model_forward( self.bert, @@ -898,8 +902,8 @@ def bert_for_multiple_choice_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bert_for_question_answering_forward( @@ -936,26 +940,28 @@ def bert_for_question_answering_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - outputs = BertPipelineForwards.bert_model_forward(self.bert, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - hidden_states=hidden_states, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=shard_config) + outputs = BertPipelineForwards.bert_model_forward( + self.bert, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + hidden_states=hidden_states, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=shard_config, + ) if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -993,12 +999,11 @@ def bert_for_question_answering_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} def get_bert_flash_attention_forward(): - try: from xformers.ops import memory_efficient_attention as me_attention except: @@ -1064,7 +1069,7 @@ def forward( distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) @@ -1084,19 +1089,17 @@ def forward( if final_attention_mask is not None: batch_size, src_len = query_layer.size()[0], query_layer.size()[2] tgt_len = key_layer.size()[2] - final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, - tgt_len).contiguous() + final_attention_mask = final_attention_mask.expand( + batch_size, self.num_attention_heads, src_len, tgt_len + ).contiguous() query_layer = query_layer.permute(0, 2, 1, 3).contiguous() key_layer = key_layer.permute(0, 2, 1, 3).contiguous() value_layer = value_layer.permute(0, 2, 1, 3).contiguous() - context_layer = me_attention(query_layer, - key_layer, - value_layer, - attn_bias=final_attention_mask, - p=self.dropout.p, - scale=scale) + context_layer = me_attention( + query_layer, key_layer, value_layer, attn_bias=final_attention_mask, p=self.dropout.p, scale=scale + ) new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) @@ -1110,7 +1113,6 @@ def forward( def get_jit_fused_bert_self_output_forward(): - from transformers.models.bert.modeling_bert import BertSelfOutput def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: @@ -1123,7 +1125,6 @@ def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: tor def get_jit_fused_bert_output_forward(): - from transformers.models.bert.modeling_bert import BertOutput def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: @@ -1136,7 +1137,6 @@ def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.T def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): - def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -1174,8 +1174,9 @@ def forward( `past_key_values`). """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.config.is_decoder: @@ -1241,12 +1242,13 @@ def forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - embedding_output = split_forward_gather_backward(embedding_output, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + embedding_output = split_forward_gather_backward( + embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if encoder_hidden_states is not None: encoder_hidden_states = split_forward_gather_backward( - encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group) + encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) encoder_outputs = self.encoder( embedding_output, @@ -1264,9 +1266,9 @@ def forward( sequence_output = encoder_outputs[0] # When sequence parallelism done, gather the output tensor in forward and split it in backward - sequence_output = gather_forward_split_backward(sequence_output, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + sequence_output = gather_forward_split_backward( + sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group + ) pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index 69730fd3d254..00b2037fbdc8 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -1,12 +1,10 @@ -import math -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.nn as nn def forward_fn(): - def forward( self, hidden_states: torch.Tensor, @@ -62,7 +60,6 @@ def forward( def get_blip2_flash_attention_forward(): - from transformers.models.blip_2.modeling_blip_2 import Blip2Attention from colossalai.kernel.cuda_native import ColoAttention @@ -80,10 +77,9 @@ def forward( mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] - attention = ColoAttention(embed_dim=self.embed_dim, - num_heads=self.num_heads, - dropout=self.dropout.p, - scale=self.scale) + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale + ) context_layer = attention(query_states, key_states, value_states) output = self.projection(context_layer) @@ -95,7 +91,6 @@ def forward( def get_jit_fused_blip2_QFormer_self_output_forward(): - from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: @@ -108,7 +103,6 @@ def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_ten def get_jit_fused_blip2_QFormer_output_forward(): - from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 66f24dc6088b..1bf87e80a461 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -30,9 +30,9 @@ def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: - - def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, - dtype: torch.dtype) -> torch.Tensor: + def build_bloom_alibi_tensor( + self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype + ) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value @@ -56,23 +56,23 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, num_heads = num_heads * world_size batch_size, seq_length = attention_mask.shape - closest_power_of_2 = 2**math.floor(math.log2(num_heads)) - base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), - device=attention_mask.device, - dtype=torch.float32) + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != num_heads: - extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - device=attention_mask.device, - dtype=torch.float32) + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32, + ) num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange(1, - 1 + 2 * num_remaining_heads, - 2, - device=attention_mask.device, - dtype=torch.int32) + extra_powers = torch.arange( + 1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32 + ) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) # Note: alibi will added to the attention bias that will be applied to the query, key product of attention @@ -87,7 +87,7 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, num_heads_per_rank = int(num_heads / dist.get_world_size(process_group)) offset = dist.get_rank(process_group) * num_heads_per_rank alibi = alibi.view(batch_size, num_heads, 1, seq_length) - alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] + alibi = alibi[:, offset : num_heads_per_rank + offset, :, :] return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) else: return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) @@ -96,9 +96,9 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, class BloomPipelineForwards: - ''' + """ This class serves as a micro library for bloom pipeline forwards. - ''' + """ @staticmethod def bloom_model_forward( @@ -117,8 +117,7 @@ def bloom_model_forward( stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']: - + ) -> Union[Tuple[torch.Tensor, ...], "BaseModelOutputWithPastAndCrossAttentions"]: logger = logging.get_logger(__name__) if deprecated_arguments.pop("position_ids", False) is not False: @@ -132,20 +131,21 @@ def bloom_model_forward( raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # add warnings here if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -184,7 +184,8 @@ def bloom_model_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False if past_key_values is None: @@ -193,7 +194,7 @@ def bloom_model_forward( seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] # source_len + past_key_values_length = past_key_values[0][0].shape[2] # source_len seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: @@ -213,20 +214,20 @@ def bloom_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) start_idx, end_idx = stage_index[0], stage_index[1] - for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), - start=start_idx): + for i, (block, layer_past) in enumerate( + zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx + ): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) @@ -257,14 +258,13 @@ def custom_forward(*inputs): if use_cache is True: presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + \ - (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if stage_manager.is_last_stage(): # Add last hidden state @@ -277,7 +277,8 @@ def custom_forward(*inputs): if stage_manager.is_last_stage(): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + ) # attention_mask is not returned ; presents = past_key_values return BaseModelOutputWithPastAndCrossAttentions( @@ -288,25 +289,27 @@ def custom_forward(*inputs): ) else: # always return dict for imediate stage - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod - def bloom_for_causal_lm_forward(self: BloomForCausalLM, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None, - **deprecated_arguments): + def bloom_for_causal_lm_forward( + self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + **deprecated_arguments, + ): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set @@ -328,30 +331,29 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False - transformer_outputs = BloomPipelineForwards.bloom_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + transformer_outputs = BloomPipelineForwards.bloom_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] lm_logits = self.lm_head(hidden_states) @@ -366,8 +368,9 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), - shift_labels.view(batch_size * seq_length)) + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -381,8 +384,8 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM, attentions=transformer_outputs.attentions, ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bloom_for_sequence_classification_forward( @@ -425,10 +428,10 @@ def bloom_for_sequence_classification_forward( # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False transformer_outputs = BloomPipelineForwards.bloom_model_forward( @@ -448,9 +451,6 @@ def bloom_for_sequence_classification_forward( shard_config=shard_config, ) past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): batch_size = hidden_states.shape[0] # update batch size @@ -468,7 +468,8 @@ def bloom_for_sequence_classification_forward( sequence_lengths = -1 logger.warning( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] @@ -506,8 +507,8 @@ def bloom_for_sequence_classification_forward( attentions=transformer_outputs.attentions, ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bloom_for_token_classification_forward( @@ -550,10 +551,10 @@ def bloom_for_token_classification_forward( # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False transformer_outputs = BloomPipelineForwards.bloom_model_forward( @@ -573,9 +574,6 @@ def bloom_for_token_classification_forward( shard_config=shard_config, ) past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -588,8 +586,9 @@ def bloom_for_token_classification_forward( labels = labels.to(logits.device) batch_size, seq_length = labels.shape loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), - labels.view(batch_size * seq_length)) + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) if not return_dict: output = (logits,) + transformer_outputs[2:] @@ -602,8 +601,8 @@ def bloom_for_token_classification_forward( attentions=transformer_outputs.attentions, ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def bloom_for_question_answering_forward( @@ -638,10 +637,10 @@ def bloom_for_question_answering_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False outputs = BloomPipelineForwards.bloom_model_forward( @@ -659,10 +658,6 @@ def bloom_for_question_answering_forward( stage_index=stage_index, shard_config=shard_config, ) - past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): sequence_output = outputs[0] @@ -700,12 +695,11 @@ def bloom_for_question_answering_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} def get_bloom_flash_attention_forward(enabel_jit_fused=False): - try: from xformers.ops import memory_efficient_attention as me_attention except: @@ -723,7 +717,6 @@ def forward( use_cache: bool = False, output_attentions: bool = False, ): - fused_qkv = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) batch_size, tgt_len, _ = query_layer.size() @@ -750,29 +743,35 @@ def forward( tgt_len = key_layer.size()[1] - attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length), - dtype=torch.float32, - device=query_layer.device, - requires_grad=True) - attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, - kv_length) * self.beta - attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask, - torch.finfo(torch.float32).min) - - context_layer = me_attention(query_layer, - key_layer, - value_layer, - attn_bias=attention_numerical_mask, - scale=self.inv_norm_factor, - p=self.attention_dropout.p) + attention_numerical_mask = torch.zeros( + (batch_size, self.num_heads, tgt_len, kv_length), + dtype=torch.float32, + device=query_layer.device, + requires_grad=True, + ) + attention_numerical_mask = ( + attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1, kv_length) * self.beta + ) + attention_numerical_mask = torch.masked_fill( + attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min + ) + + context_layer = me_attention( + query_layer, + key_layer, + value_layer, + attn_bias=attention_numerical_mask, + scale=self.inv_norm_factor, + p=self.attention_dropout.p, + ) context_layer = context_layer.reshape(-1, kv_length, self.hidden_size) if self.pretraining_tp > 1 and self.slow_but_exact: slices = self.hidden_size / self.pretraining_tp output_tensor = torch.zeros_like(context_layer) for i in range(self.pretraining_tp): output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices):int((i + 1) * slices)], - self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: output_tensor = self.dense(context_layer) @@ -787,7 +786,6 @@ def forward( def get_jit_fused_bloom_attention_forward(): - from transformers.models.bloom.modeling_bloom import BloomAttention def forward( @@ -801,7 +799,7 @@ def forward( use_cache: bool = False, output_attentions: bool = False, ): - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) @@ -867,8 +865,8 @@ def forward( output_tensor = torch.zeros_like(context_layer) for i in range(self.pretraining_tp): output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices):int((i + 1) * slices)], - self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: output_tensor = self.dense(context_layer) @@ -885,7 +883,6 @@ def forward( def get_jit_fused_bloom_mlp_forward(): - from transformers.models.bloom.modeling_bloom import BloomMLP def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: @@ -896,8 +893,8 @@ def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp for i in range(self.pretraining_tp): intermediate_output = intermediate_output + F.linear( - hidden_states[:, :, int(i * slices):int((i + 1) * slices)], - self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)], + hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], ) else: intermediate_output = self.dense_4h_to_h(hidden_states) @@ -908,7 +905,6 @@ def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) def get_jit_fused_bloom_gelu_forward(): - from transformers.models.bloom.modeling_bloom import BloomGelu from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction @@ -924,7 +920,6 @@ def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor: def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): - from transformers import BloomModel def forward( @@ -951,8 +946,9 @@ def forward( raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -986,7 +982,8 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False # Compute alibi tensor: check build_alibi_tensor documentation @@ -1009,9 +1006,9 @@ def forward( ) # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: @@ -1020,7 +1017,6 @@ def forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) @@ -1054,9 +1050,9 @@ def custom_forward(*inputs): all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) # Add last hidden state hidden_states = self.ln_f(hidden_states) diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 16dcf87c8cfc..8934068d609c 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -1,26 +1,19 @@ """ PyTorch ChatGLM model. """ -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple import torch -import torch.nn.functional as F import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward -from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( - ChatGLMForConditionalGeneration, - ChatGLMModel, - GLMBlock, -) +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel def get_flash_core_attention_forward(): - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention from .chatglm2_6b.modeling_chatglm import CoreAttention @@ -30,15 +23,15 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_ if pytorch_major_version >= 2: query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - is_causal=True) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True + ) else: if attention_mask is not None: attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask + ) context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) @@ -60,15 +53,15 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_ flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() attn_mask_type = AttnMaskType.paddedcausal - attention = ColoAttention(embed_dim=self.hidden_size_per_partition, - num_heads=self.num_attention_heads_per_partition, - dropout=self.attention_dropout.p, - scale=scale) - context_layer = attention(query_layer, - key_layer, - value_layer, - attn_mask=flash_attention_mask, - attn_mask_type=attn_mask_type) + attention = ColoAttention( + embed_dim=self.hidden_size_per_partition, + num_heads=self.num_attention_heads_per_partition, + dropout=self.attention_dropout.p, + scale=scale, + ) + context_layer = attention( + query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + ) context_layer = context_layer.permute(1, 0, -1).contiguous() @@ -78,7 +71,6 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_ def get_jit_fused_glm_block_forward(): - from .chatglm2_6b.modeling_chatglm import GLMBlock def forward( @@ -129,9 +121,9 @@ def forward( class ChatGLMPipelineForwards: - ''' + """ This class serves as a micro library for ChatGLM model forwards under pipeline parallelism. - ''' + """ @staticmethod def chatglm_model_forward( @@ -151,19 +143,20 @@ def chatglm_model_forward( shard_config: ShardConfig = None, ): logger = logging.get_logger(__name__) - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False if stage_manager.is_first_stage(): batch_size, seq_length = input_ids.shape @@ -174,12 +167,13 @@ def chatglm_model_forward( seq_length, batch_size = hidden_states.shape[:2] if self.pre_seq_len is not None: if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, - device=input_ids.device, - dtype=inputs_embeds.dtype) + past_key_values = self.get_prompt( + batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype + ) if attention_mask is not None: - attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], - dim=-1) + attention_mask = torch.cat( + [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 + ) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) @@ -196,37 +190,41 @@ def chatglm_model_forward( if self.encoder.gradient_checkpointing and self.encoder.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False all_self_attentions = None all_hidden_states = () if output_hidden_states else None start_idx, end_idx = stage_index[0], stage_index[1] if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward(hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if self.encoder.gradient_checkpointing and self.encoder.training: - layer_ret = torch.utils.checkpoint.checkpoint(layer, hidden_states, attention_mask, rotary_pos_emb, - past_key_values[idx], use_cache) + layer_ret = torch.utils.checkpoint.checkpoint( + layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache + ) else: - layer_ret = layer(hidden_states, - full_attention_mask, - rotary_pos_emb, - kv_cache=past_key_values[idx], - use_cache=use_cache) + layer_ret = layer( + hidden_states, + full_attention_mask, + rotary_pos_emb, + kv_cache=past_key_values[idx], + use_cache=use_cache, + ) hidden_states, kv_cache = layer_ret if use_cache: presents = presents + (kv_cache,) if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward(hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): @@ -235,7 +233,8 @@ def chatglm_model_forward( hidden_states = self.encoder.final_layernorm(hidden_states) if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=presents, @@ -243,28 +242,30 @@ def chatglm_model_forward( attentions=all_self_attentions, ) else: - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod - def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None): - logger = logging.get_logger(__name__) + def chatglm_for_conditional_generation_forward( + self: ChatGLMForConditionalGeneration, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + logging.get_logger(__name__) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = ChatGLMPipelineForwards.chatglm_model_forward( self.transformer, input_ids=input_ids, @@ -312,7 +313,6 @@ def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGenera def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): - def forward( self, input_ids, @@ -325,10 +325,11 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size, seq_length = input_ids.shape @@ -365,9 +366,9 @@ def forward( # Run encoder. # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] - inputs_embeds = split_forward_gather_backward(inputs_embeds, - dim=0, - process_group=shard_config.tensor_parallel_process_group) + inputs_embeds = split_forward_gather_backward( + inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group + ) hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( inputs_embeds, full_attention_mask, @@ -377,17 +378,21 @@ def forward( output_hidden_states=output_hidden_states, ) - hidden_states = gather_forward_split_backward(hidden_states, - dim=0, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + ) if not return_dict: - return tuple(v for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - ] if v is not None) + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, diff --git a/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py index 3e78732be2da..bb774676a4d4 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/configuration_chatglm.py @@ -4,32 +4,34 @@ class ChatGLMConfig(PretrainedConfig): model_type = "chatglm" - def __init__(self, - num_layers=28, - padded_vocab_size=65024, - hidden_size=4096, - ffn_hidden_size=13696, - kv_channels=128, - num_attention_heads=32, - seq_length=2048, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - rmsnorm=True, - apply_residual_connection_post_layernorm=False, - post_layer_norm=True, - add_bias_linear=False, - add_qkv_bias=False, - bias_dropout_fusion=True, - multi_query_attention=False, - multi_query_group_num=1, - apply_query_key_layer_scaling=True, - attention_softmax_in_fp32=True, - fp32_residual_connection=False, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs): + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs, + ): self.num_layers = num_layers self.vocab_size = padded_vocab_size self.padded_vocab_size = padded_vocab_size diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index a21ee0231422..3a8d90ec7328 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -37,10 +37,9 @@ import copy import math -import re import sys import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.nn.functional as F @@ -80,7 +79,6 @@ def default_init(cls, *args, **kwargs): class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() @@ -100,7 +98,7 @@ def __init__(self, config: ChatGLMConfig): self.prefix_projection = config.prefix_projection if self.prefix_projection: # Use a two-layer MLP to encode the prefix - kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) self.trans = torch.nn.Sequential( torch.nn.Linear(kv_size, config.hidden_size), @@ -151,10 +149,9 @@ def split_tensor_along_last_dim( class RotaryEmbedding(nn.Module): - def __init__(self, dim, original_impl=False, device=None, dtype=None): super().__init__() - inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) self.register_buffer("inv_freq", inv_freq) self.dim = dim self.original_impl = original_impl @@ -174,7 +171,7 @@ def forward_impl( https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. """ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) # Create position indexes `[0, 1, ..., seq_len - 1]` seq_idx = torch.arange(seq_len, dtype=dtype, device=device) @@ -220,7 +217,6 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): super().__init__() self.elementwise_affine = True @@ -236,7 +232,6 @@ def forward(self, hidden_states: torch.Tensor): class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): super(CoreAttention, self).__init__() @@ -250,7 +245,7 @@ def __init__(self, config: ChatGLMConfig, layer_number): # Per attention head and per partition values. self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads) + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads coeff = None @@ -267,15 +262,15 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): if pytorch_major_version >= 2: query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - is_causal=True) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True + ) else: if attention_mask is not None: attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask + ) context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) @@ -307,8 +302,8 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): # Raw attention scores. [b * np, sq, sk] matmul_result = torch.baddbmm( matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=(1.0 / self.norm_factor), ) @@ -325,7 +320,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): attention_scores = attention_scores.float() if self.coeff is not None: attention_scores = attention_scores * self.coeff - if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]): + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: attention_mask = torch.ones( output_size[0], 1, @@ -388,15 +383,16 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): self.projection_size = config.kv_channels * config.num_attention_heads # Per attention head and per partition values. - self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads) + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads self.multi_query_attention = config.multi_query_attention self.qkv_hidden_size = 3 * self.projection_size if self.multi_query_attention: self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = (self.projection_size + - 2 * self.hidden_size_per_attention_head * config.multi_query_group_num) + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) self.query_key_value = nn.Linear( config.hidden_size, self.qkv_hidden_size, @@ -459,18 +455,27 @@ def forward( ], dim=-1, ) - query_layer = query_layer.view(query_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) - key_layer = key_layer.view(key_layer.size()[:-1] + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - )) - value_layer = value_layer.view(value_layer.size()[:-1] + ( - self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head, - )) + query_layer = query_layer.view( + query_layer.size()[:-1] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + ( + self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head, + ) + ) else: new_tensor_shape = mixed_x_layer.size()[:-1] + ( self.num_attention_heads_per_partition, @@ -504,10 +509,13 @@ def forward( self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, ) - key_layer = key_layer.contiguous().view(key_layer.size()[:2] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) value_layer = value_layer.unsqueeze(-2) value_layer = value_layer.expand( -1, @@ -516,10 +524,13 @@ def forward( self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, ) - value_layer = value_layer.contiguous().view(value_layer.size()[:2] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - )) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + ) # ================================== # core attention computation @@ -600,7 +611,7 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): super(GLMBlock, self).__init__() self.layer_number = layer_number - self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm) + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm self.fp32_residual_connection = config.fp32_residual_connection @@ -724,7 +735,8 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False all_self_attentions = None @@ -806,7 +818,7 @@ def get_masks(self, input_ids, past_key_values, padding_mask=None): def get_position_ids(self, input_ids, device): batch_size, seq_length = input_ids.shape - position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)) + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) return position_ids def _set_gradient_checkpointing(self, module, value=False): @@ -843,7 +855,6 @@ def forward(self, input_ids): class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): super().__init__(config) if empty_init: @@ -860,8 +871,9 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): # Rotary positional embeddings self.seq_length = config.seq_length - rotary_dim = (config.hidden_size // - config.num_attention_heads if config.kv_channels is None else config.kv_channels) + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) self.rotary_pos_emb = RotaryEmbedding( rotary_dim // 2, @@ -891,7 +903,7 @@ def get_input_embeddings(self): return self.embedding.word_embeddings def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)) + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) past_key_values = past_key_values.view( batch_size, @@ -917,10 +929,11 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size, seq_length = input_ids.shape @@ -966,12 +979,16 @@ def forward( ) if not return_dict: - return tuple(v for v in [ - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - ] if v is not None) + return tuple( + v + for v in [ + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -988,7 +1005,6 @@ def quantize(self, weight_bit_width: int): class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): super().__init__(config) @@ -1009,7 +1025,8 @@ def _update_model_kwargs_for_generation( ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format) + outputs, standardize_cache_format=standardize_cache_format + ) # update attention mask if "attention_mask" in model_kwargs: @@ -1067,7 +1084,7 @@ def forward( return_last_logit: Optional[bool] = False, ): use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids=input_ids, @@ -1113,8 +1130,9 @@ def forward( ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], - beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct @@ -1122,10 +1140,13 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], Output shares the same memory storage as `past`. """ - return tuple(( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) for layer_past in past) + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) def process_response(self, response): response = response.strip() @@ -1180,7 +1201,7 @@ def chat( } inputs = self.build_inputs(tokenizer, query, history=history) outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] response = tokenizer.decode(outputs) response = self.process_response(response) history = history + [(query, response)] @@ -1227,14 +1248,14 @@ def stream_chat( attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) inputs["attention_mask"] = attention_mask for outputs in self.stream_generate( - **inputs, - past_key_values=past_key_values, - return_past_key_values=return_past_key_values, - **gen_kwargs, + **inputs, + past_key_values=past_key_values, + return_past_key_values=return_past_key_values, + **gen_kwargs, ): if return_past_key_values: outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :] response = tokenizer.decode(outputs) if response and response[-1] != "�": response = self.process_response(response) @@ -1269,7 +1290,7 @@ def stream_generate( if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None) + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " @@ -1278,7 +1299,7 @@ def stream_generate( UserWarning, ) elif generation_config.max_new_tokens is not None: - generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length) + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length if not has_default_max_length: logger.warn( f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" @@ -1289,14 +1310,16 @@ def stream_generate( ) if input_ids_seq_length >= generation_config.max_length: - input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids") - logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`.") + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) # 2. Set generation parameters if not already defined - logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList()) - stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() logits_processor = self._get_logits_processor( generation_config=generation_config, @@ -1306,8 +1329,9 @@ def stream_generate( logits_processor=logits_processor, ) - stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, - stopping_criteria=stopping_criteria) + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) logits_warper = self._get_logits_warper(generation_config) unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) @@ -1337,9 +1361,9 @@ def stream_generate( # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation(outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) if return_past_key_values: yield input_ids, outputs.past_key_values diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 84deafefeadd..21f06393071d 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -26,32 +26,32 @@ class GPT2PipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of GPT2 models under pipeline setting. - ''' + """ @staticmethod def gpt2_model_forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. @@ -62,16 +62,16 @@ def gpt2_model_forward( # Preprocess passed in arguments # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False if stage_manager.is_first_stage(): @@ -115,7 +115,7 @@ def gpt2_model_forward( # positions we want to attend and the dtype's smallest value for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention @@ -156,7 +156,8 @@ def gpt2_model_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False presents = () if use_cache else None all_self_attentions = () if output_attentions else None @@ -166,9 +167,9 @@ def gpt2_model_forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config.enable_sequence_parallelism: - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) # Going through held blocks. start_idx, end_idx = stage_index[0], stage_index[1] @@ -186,7 +187,6 @@ def gpt2_model_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache, output_attentions) @@ -225,9 +225,9 @@ def custom_forward(*inputs): # When sequence parallelism done, gather the output tensor in forward and split it in backward if shard_config.enable_sequence_parallelism: - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) if stage_manager.is_last_stage(): hidden_states = self.ln_f(hidden_states) @@ -241,8 +241,10 @@ def custom_forward(*inputs): if stage_manager.is_last_stage(): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None) + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, @@ -253,62 +255,65 @@ def custom_forward(*inputs): ) else: # always return dict for intermediate stage - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod def gpt2_lmhead_model_forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. - Please refer to original code of transformers for more details. - """ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} + return {"hidden_states": outputs["hidden_states"]} hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) @@ -337,25 +342,26 @@ def gpt2_lmhead_model_forward( @staticmethod def gpt2_double_heads_model_forward( - self: GPT2DoubleHeadsModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - mc_token_ids: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - mc_labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: + self: GPT2DoubleHeadsModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]: r""" mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - @@ -373,26 +379,28 @@ def gpt2_double_heads_model_forward( ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} + return {"hidden_states": outputs["hidden_states"]} hidden_states = outputs[0] lm_logits = self.lm_head(hidden_states) @@ -428,22 +436,23 @@ def gpt2_double_heads_model_forward( @staticmethod def gpt2_for_question_answering_forward( - self: GPT2ForQuestionAnswering, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: + self: GPT2ForQuestionAnswering, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]: r""" start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for position (index) of the start of the labelled span for computing the token classification loss. @@ -459,24 +468,26 @@ def gpt2_for_question_answering_forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} + return {"hidden_states": outputs["hidden_states"]} sequence_output = outputs[0] @@ -516,23 +527,24 @@ def gpt2_for_question_answering_forward( @staticmethod def gpt2_for_token_classification_forward( - self: GPT2ForTokenClassification, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]: + self: GPT2ForTokenClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -544,26 +556,28 @@ def gpt2_for_token_classification_forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} + return {"hidden_states": outputs["hidden_states"]} hidden_states = outputs[0] hidden_states = self.dropout(hidden_states) @@ -588,23 +602,24 @@ def gpt2_for_token_classification_forward( @staticmethod def gpt2_for_sequence_classification_forward( - self: GPT2ForSequenceClassification, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: + self: GPT2ForSequenceClassification, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -613,38 +628,41 @@ def gpt2_for_sequence_classification_forward( # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification.forward. # Please refer to original code of transformers for more details. - """ + """ logger = logging.get_logger(__name__) if input_ids is not None: batch_size, _ = input_ids.shape[:2] else: batch_size, _ = hidden_states.shape[:2] - assert (self.config.pad_token_id is not None - or batch_size == 1), "Cannot handle batch sizes > 1 if no padding token is defined." + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - shard_config=shard_config) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + shard_config=shard_config, + ) # If not at the last stage, return hidden_states as in GPT2Model if not stage_manager.is_last_stage(): - return {'hidden_states': outputs['hidden_states']} + return {"hidden_states": outputs["hidden_states"]} hidden_states = outputs[0] logits = self.score(hidden_states) @@ -658,7 +676,8 @@ def gpt2_for_sequence_classification_forward( sequence_lengths = -1 logger.warning_once( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] @@ -698,7 +717,6 @@ def gpt2_for_sequence_classification_forward( def get_gpt2_flash_attention_forward(): - from transformers.models.gpt2.modeling_gpt2 import GPT2Attention from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention @@ -722,12 +740,12 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.") + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) query = self.q_attn(hidden_states) key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) @@ -759,15 +777,14 @@ def forward( attn_mask_type = AttnMaskType.padding flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() - scale = value.size(-1)**-0.5 + scale = value.size(-1) ** -0.5 if self.scale_attn_by_inverse_layer_idx: scale = scale * (1 / float(self.layer_idx + 1)) # use coloattention - attention = ColoAttention(embed_dim=self.embed_dim, - num_heads=self.num_heads, - dropout=self.attn_dropout.p, - scale=scale) + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale + ) attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) @@ -781,7 +798,6 @@ def forward( def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -799,8 +815,9 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -849,7 +866,7 @@ def forward( # positions we want to attend and the dtype's smallest value for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention @@ -886,7 +903,8 @@ def forward( if use_cache: logger = logging.get_logger(__name__) logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False presents = () if use_cache else None @@ -896,9 +914,9 @@ def forward( # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] - hidden_states = split_forward_gather_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = split_forward_gather_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): # Model parallel @@ -918,7 +936,6 @@ def forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache, output_attentions) @@ -962,9 +979,9 @@ def custom_forward(*inputs): hidden_states = hidden_states.to("cuda:" + str(k + 1)) # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = gather_forward_split_backward(hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group) + hidden_states = gather_forward_split_backward( + hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group + ) hidden_states = self.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) @@ -974,8 +991,10 @@ def custom_forward(*inputs): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None) + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, diff --git a/colossalai/shardformer/modeling/jit.py b/colossalai/shardformer/modeling/jit.py index 6434348ef823..c92847a3fbcc 100644 --- a/colossalai/shardformer/modeling/jit.py +++ b/colossalai/shardformer/modeling/jit.py @@ -2,7 +2,6 @@ def get_dropout_add_func(): - from transformers.models.bloom.modeling_bloom import dropout_add def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: @@ -12,7 +11,6 @@ def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, def get_jit_fused_dropout_add_func(): - from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: @@ -25,7 +23,6 @@ def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, def get_jit_fused_gelu_forward_func(): - from colossalai.kernel.jit.bias_gelu import bias_gelu def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index ff622c306c59..4b6c8342534a 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -15,10 +15,10 @@ class LlamaPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of Llama models under pipeline setting. - ''' + """ @staticmethod def llama_model_forward( @@ -39,8 +39,9 @@ def llama_model_forward( logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -69,13 +70,13 @@ def llama_model_forward( # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False if past_key_values is not None: @@ -83,10 +84,9 @@ def llama_model_forward( seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -94,16 +94,18 @@ def llama_model_forward( # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=hidden_states.device) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, - past_key_values_length) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False # decoder layers @@ -121,7 +123,6 @@ def llama_model_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -169,7 +170,7 @@ def custom_forward(*inputs): attentions=all_self_attns, ) # always return dict for imediate stage - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod def llama_for_causal_lm_forward( @@ -189,42 +190,43 @@ def llama_for_causal_lm_forward( stage_index: Optional[List[int]] = None, ): r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: + Returns: - Example: + Example: - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -244,9 +246,6 @@ def llama_for_causal_lm_forward( stage_index=stage_index, ) past_key_values = None - all_hidden_states = None - all_self_attentions = None - all_cross_attentions = None if stage_manager.is_last_stage(): hidden_states = outputs[0] @@ -276,8 +275,8 @@ def llama_for_causal_lm_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def llama_for_sequence_classification_forward( @@ -307,10 +306,10 @@ def llama_for_sequence_classification_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False transformer_outputs = LlamaPipelineForwards.llama_model_forward( @@ -388,16 +387,15 @@ def llama_for_sequence_classification_forward( ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} def get_llama_flash_attention_forward(): - - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention - from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb + from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + llama_version = 2 try: from transformers.models.llama.modeling_llama import repeat_kv @@ -453,16 +451,15 @@ def forward( if attention_mask != None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() attn_mask_type = AttnMaskType.paddedcausal attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads) - attn_output = attention(query_states, - key_states, - value_states, - attn_mask=flash_attention_mask, - attn_mask_type=attn_mask_type) + attn_output = attention( + query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + ) attn_output = self.o_proj(attn_output) diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index ad088f3702e5..e0978d38e110 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -21,16 +21,17 @@ class OPTPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of OPT models under pipeline setting. - ''' + """ @staticmethod def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] from transformers.models.opt.modeling_opt import _make_causal_mask + combined_attention_mask = None if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( @@ -42,10 +43,12 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, - tgt_len=input_shape[-1]).to(device) - combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + - combined_attention_mask) + expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to( + device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) return combined_attention_mask @@ -79,17 +82,19 @@ def opt_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: - ''' + """ This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward - ''' + """ from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.utils import logging + logger = logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -133,10 +138,12 @@ def opt_model_forward( elif attention_mask.shape[1] != mask_seq_length: raise ValueError( f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{mask_seq_length} (sum of the lengths of current and past inputs)") + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) - causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, - device, past_key_values_length) + causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask( + attention_mask, input_shape, _dtype, device, past_key_values_length + ) if stage_manager.is_first_stage(): pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) @@ -145,21 +152,22 @@ def opt_model_forward( if decoder.gradient_checkpointing and decoder.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False # decoder layers @@ -173,7 +181,8 @@ def opt_model_forward( if attn_mask.size()[0] != (len(decoder.layers)): raise ValueError( f"The `{mask_name}` should be specified for {len(decoder.layers)} layers, but it is for" - f" {head_mask.size()[0]}.") + f" {head_mask.size()[0]}." + ) start_idx, end_idx = stage_index[0], stage_index[1] @@ -195,7 +204,6 @@ def opt_model_forward( if decoder.gradient_checkpointing and decoder.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -250,7 +258,7 @@ def custom_forward(*inputs): attentions=all_self_attns, ) else: - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} @staticmethod def opt_for_causal_lm_forward( @@ -275,8 +283,9 @@ def opt_for_causal_lm_forward( """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -319,8 +328,8 @@ def opt_for_causal_lm_forward( attentions=outputs.attentions, ) else: - hidden_states = outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def opt_for_sequence_classification_forward( @@ -348,19 +357,21 @@ def opt_for_sequence_classification_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) + transformer_outputs = OPTPipelineForwards.opt_model_forward( + self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -377,7 +388,8 @@ def opt_for_sequence_classification_forward( sequence_lengths = -1 logger.warning( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`") + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] @@ -416,8 +428,8 @@ def opt_for_sequence_classification_forward( attentions=transformer_outputs.attentions, ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} @staticmethod def opt_for_question_answering_forward( @@ -443,19 +455,21 @@ def opt_for_question_answering_forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = OPTPipelineForwards.opt_model_forward(self.model, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) + transformer_outputs = OPTPipelineForwards.opt_model_forward( + self.model, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) if stage_manager.is_last_stage(): hidden_states = transformer_outputs[0] @@ -493,12 +507,11 @@ def opt_for_question_answering_forward( attentions=transformer_outputs.attentions, ) else: - hidden_states = transformer_outputs.get('hidden_states') - return {'hidden_states': hidden_states} + hidden_states = transformer_outputs.get("hidden_states") + return {"hidden_states": hidden_states} def get_opt_flash_attention_forward(): - from transformers.models.opt.modeling_opt import OPTAttention from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention @@ -555,27 +568,27 @@ def forward( src_len = key_states.size(1) if layer_head_mask != None: if layer_head_mask.size() != (self.num_heads,): - raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) flash_attention_mask = None attn_mask_type = AttnMaskType.causal if attention_mask != None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}") + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous() attn_mask_type = AttnMaskType.paddedcausal - attention = ColoAttention(embed_dim=self.embed_dim, - num_heads=self.num_heads, - dropout=self.dropout, - scale=self.scaling) - attn_output = attention(query_states, - key_states, - value_states, - attn_mask=flash_attention_mask, - attn_mask_type=attn_mask_type) + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling + ) + attn_output = attention( + query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type + ) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value @@ -584,7 +597,6 @@ def forward( def get_jit_fused_opt_decoder_layer_forward(): - from transformers.models.opt.modeling_opt import OPTDecoderLayer def forward( diff --git a/colossalai/shardformer/modeling/sam.py b/colossalai/shardformer/modeling/sam.py index c40c02ec411a..26e0b224d3ab 100644 --- a/colossalai/shardformer/modeling/sam.py +++ b/colossalai/shardformer/modeling/sam.py @@ -7,20 +7,23 @@ def forward_fn(): - def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: batch_size, height, width, _ = hidden_states.shape # qkv with shape (3, batch_size, nHead, height * width, channel) - qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, - -1).permute(2, 0, 3, 1, 4)) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) # q, k, v with shape (batch_size * nHead, height * width, channel) query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) attn_weights = (query * self.scale) @ key.transpose(-2, -1) if self.use_rel_pos: - attn_weights = self.add_decomposed_rel_pos(attn_weights, query, self.rel_pos_h, self.rel_pos_w, - (height, width), (height, width)) + attn_weights = self.add_decomposed_rel_pos( + attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) @@ -45,8 +48,8 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch def get_sam_flash_attention_forward(): - from transformers.models.sam.modeling_sam import SamAttention + try: from xformers.ops import memory_efficient_attention as me_attention except: @@ -62,11 +65,9 @@ def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor: batch, n_tokens, n_heads, c_per_head = hidden_states.shape return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) - def forward(self: SamAttention, - query: Tensor, - key: Tensor, - value: Tensor, - attention_similarity: Tensor = None) -> Tensor: + def forward( + self: SamAttention, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None + ) -> Tensor: # Input projections query = self.q_proj(query) key = self.k_proj(key) @@ -96,8 +97,8 @@ def forward(self: SamAttention, def get_sam_vision_flash_attention_forward(): - from transformers.models.sam.modeling_sam import SamVisionAttention + try: from xformers.ops import memory_efficient_attention as me_attention except: @@ -181,8 +182,11 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: batch_size, height, width, _ = hidden_states.shape # qkv with shape (3, batch_size, nHead, height * width, channel) - qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, - -1).permute(2, 0, 1, 3, 4)) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 1, 3, 4) + ) query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0) diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py index 9cc071f91dfc..f67aa84e4e72 100644 --- a/colossalai/shardformer/modeling/t5.py +++ b/colossalai/shardformer/modeling/t5.py @@ -17,10 +17,10 @@ class T5PipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of T5 models under pipeline setting. - ''' + """ @staticmethod def t5_stack_forward( @@ -44,7 +44,6 @@ def t5_stack_forward( stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Stack.forward. # Please refer to original code of transformers for more details. @@ -52,16 +51,16 @@ def t5_stack_forward( # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False if use_cache is True: if not in_decoder: @@ -69,7 +68,8 @@ def t5_stack_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False stage = stage_manager.stage @@ -97,7 +97,8 @@ def t5_stack_forward( else: err_msg_prefix = "decoder_" if in_decoder else "" raise ValueError( - f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -108,7 +109,8 @@ def t5_stack_forward( else: if hidden_states is None: raise ValueError( - "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." + ) input_shape = hidden_states.size()[:-1] batch_size, seq_length = input_shape[0], input_shape[1] device = hidden_states.device @@ -153,7 +155,6 @@ def t5_stack_forward( start_idx, end_idx = stage_index[0], stage_index[1] for i in range(start_idx, end_idx): - past_key_value = past_key_values[i] layer_module = self.block[i] layer_head_mask = head_mask[i] @@ -163,7 +164,6 @@ def t5_stack_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): return tuple(module(*inputs, use_cache, output_attentions)) @@ -179,7 +179,7 @@ def custom_forward(*inputs): encoder_decoder_position_bias, layer_head_mask, cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing + None, # past_key_value is always None with gradient checkpointing ) else: layer_outputs = layer_module( @@ -220,13 +220,17 @@ def custom_forward(*inputs): hidden_states = self.dropout(hidden_states) if not return_dict: - return tuple(v for v in [ - hidden_states, - present_key_value_states, - all_hidden_states, - all_attentions, - all_cross_attentions, - ] if v is not None) + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_value_states, @@ -236,10 +240,10 @@ def custom_forward(*inputs): ) else: return { - 'hidden_states': hidden_states, - 'position_bias': position_bias, - 'encoder_decoder_position_bias': encoder_decoder_position_bias, - 'backward_tensor_keys': ['hidden_states'] + "hidden_states": hidden_states, + "position_bias": position_bias, + "encoder_decoder_position_bias": encoder_decoder_position_bias, + "backward_tensor_keys": ["hidden_states"], } @staticmethod @@ -269,7 +273,6 @@ def t5_model_forward( stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: - # This function is modified on the basis of transformers.models.t5.modeling_t5.T5Model.forward. # Please refer to original code of transformers for more details. @@ -287,16 +290,16 @@ def t5_model_forward( # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask @@ -322,10 +325,11 @@ def t5_model_forward( position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_starting_stage=decoder_starting_stage, + ) if stage_manager.stage == decoder_starting_stage - 1: # last stage of encoder - return {'encoder_hidden_states': encoder_outputs[0]} + return {"encoder_hidden_states": encoder_outputs[0]} else: return encoder_outputs @@ -360,23 +364,26 @@ def t5_model_forward( position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_starting_stage=decoder_starting_stage, + ) # Directly return outputs of overloaded T5Stack forward if not at last stage. if not at_last_decoder_stage: # encoder_hidden_states should be passed to the next stage - decoder_outputs['encoder_hidden_states'] = encoder_hidden_states + decoder_outputs["encoder_hidden_states"] = encoder_hidden_states return decoder_outputs if not return_dict: return decoder_outputs + encoder_hidden_states else: - return Seq2SeqModelOutput(last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_hidden_states) + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) @staticmethod def t5_for_conditional_generation_forward( @@ -406,7 +413,6 @@ def t5_for_conditional_generation_forward( stage_index: Optional[List[int]] = None, decoder_starting_stage: Optional[int] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: - # This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward. # Please refer to original code of transformers for more details. @@ -424,16 +430,16 @@ def t5_for_conditional_generation_forward( # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask @@ -460,10 +466,11 @@ def t5_for_conditional_generation_forward( position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_starting_stage=decoder_starting_stage, + ) if stage_manager.stage == decoder_starting_stage - 1: # last stage of encoder - return {'encoder_hidden_states': encoder_outputs[0]} + return {"encoder_hidden_states": encoder_outputs[0]} else: return encoder_outputs @@ -502,12 +509,13 @@ def t5_for_conditional_generation_forward( position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_starting_stage=decoder_starting_stage, + ) # Directly return outputs of overloaded T5Stack forward if not at last stage. if not at_last_decoder_stage: # encoder_hidden_states should be passed to the next stage - decoder_outputs['encoder_hidden_states'] = encoder_hidden_states + decoder_outputs["encoder_hidden_states"] = encoder_hidden_states return decoder_outputs sequence_output = decoder_outputs[0] @@ -530,13 +538,15 @@ def t5_for_conditional_generation_forward( output = (lm_logits,) + decoder_outputs[1:] + encoder_hidden_states return ((loss,) + output) if loss is not None else output - return Seq2SeqLMOutput(loss=loss, - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_hidden_states) + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_hidden_states, + ) @staticmethod def t5_encoder_model_forward( @@ -562,26 +572,27 @@ def t5_encoder_model_forward( ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = T5PipelineForwards.t5_stack_forward(self.encoder, - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - position_bias=position_bias, - encoder_decoder_position_bias=encoder_decoder_position_bias, - stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + outputs = T5PipelineForwards.t5_stack_forward( + self.encoder, + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + position_bias=position_bias, + encoder_decoder_position_bias=encoder_decoder_position_bias, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) return outputs def get_t5_flash_attention_forward(): - try: from xformers.ops import memory_efficient_attention as me_attention except: @@ -655,19 +666,21 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): return hidden_states # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) # get key/value states - key_states = project(hidden_states, self.k, key_value_states, - past_key_value[0] if past_key_value is not None else None) - value_states = project(hidden_states, self.v, key_value_states, - past_key_value[1] if past_key_value is not None else None) + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) if position_bias is None: if not self.has_relative_attention_bias: - position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), - device=query_states.device, - dtype=query_states.dtype) + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=query_states.device, dtype=query_states.dtype + ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: @@ -676,10 +689,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # if key and values are already calculated # we want only the last query position bias if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1):, :] + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -689,12 +702,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias position_bias_masked = position_bias_masked.contiguous() - attn_output = me_attention(query_states, - key_states, - value_states, - attn_bias=position_bias_masked, - p=self.dropout, - scale=1.0) + attn_output = me_attention( + query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0 + ) attn_output = unshape(attn_output) attn_output = self.o(attn_output) @@ -708,7 +718,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): def get_jit_fused_T5_layer_ff_forward(): - from transformers.models.t5.modeling_t5 import T5LayerFF def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor: @@ -721,7 +730,6 @@ def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor: def get_T5_layer_self_attention_forward(): - from transformers.models.t5.modeling_t5 import T5LayerSelfAttention def forward( @@ -745,14 +753,13 @@ def forward( output_attentions=output_attentions, ) hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) - outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs return forward def get_T5_layer_cross_attention_forward(): - from transformers.models.t5.modeling_t5 import T5LayerCrossAttention def forward( @@ -780,7 +787,7 @@ def forward( output_attentions=output_attentions, ) layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training) - outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them return outputs return forward diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 2ce52163ac32..2db83b912112 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -1,5 +1,5 @@ import math -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder @@ -17,7 +17,6 @@ def _encoder_forward( return_dict: bool = True, stage_manager: PipelineStageManager = None, ) -> Union[tuple, BaseModelOutput]: - for i in range(start_idx, end_idx): layer_module = encoder.layer[i] @@ -26,7 +25,6 @@ def _encoder_forward( if encoder.gradient_checkpointing and encoder.training: def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, False) @@ -54,7 +52,6 @@ def custom_forward(*inputs): def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): - from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling def pp_forward( @@ -69,19 +66,19 @@ def pp_forward( hidden_states: Optional[torch.FloatTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" - bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - """ + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict logger = logging.get_logger(__name__) # Preprocess passed in arguments if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False # Prepare head mask if needed @@ -100,11 +97,13 @@ def pp_forward( if pixel_values.dtype != expected_dtype: pixel_values = pixel_values.to(expected_dtype) - embedding_output = self.embeddings(pixel_values, - bool_masked_pos=bool_masked_pos, - interpolate_pos_encoding=interpolate_pos_encoding) + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) else: - assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + assert ( + hidden_states is not None + ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" # Go through encoder if not stage_manager.is_last_stage(): @@ -117,7 +116,7 @@ def pp_forward( return_dict=return_dict, stage_manager=stage_manager, ) - return {'hidden_states': hidden_states} + return {"hidden_states": hidden_states} else: encoder_outputs = _encoder_forward( encoder=self.encoder, @@ -149,7 +148,6 @@ def pp_forward( def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): - from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.models.vit.modeling_vit import ImageClassifierOutput @@ -173,7 +171,9 @@ def pp_forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if not stage_manager.is_first_stage(): - assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + assert ( + hidden_states is not None + ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" outputs = self.vit( pixel_values, @@ -234,7 +234,6 @@ def pp_forward( def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]): - import math import torch.nn as nn @@ -286,19 +285,24 @@ def pp_forward( raise ValueError( "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " "the reconstructed image has the same dimensions as the input." - f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.") + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) if not stage_manager.is_first_stage(): - assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" - - outputs = self.vit(pixel_values, - bool_masked_pos=bool_masked_pos, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - interpolate_pos_encoding=interpolate_pos_encoding, - return_dict=return_dict, - hidden_states=hidden_states) + assert ( + hidden_states is not None + ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + outputs = self.vit( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + hidden_states=hidden_states, + ) if not stage_manager.is_last_stage(): return outputs else: @@ -317,9 +321,12 @@ def pp_forward( if bool_masked_pos is not None: size = self.config.image_size // self.config.patch_size bool_masked_pos = bool_masked_pos.reshape(-1, size, size) - mask = (bool_masked_pos.repeat_interleave(self.config.patch_size, - 1).repeat_interleave(self.config.patch_size, - 2).unsqueeze(1).contiguous()) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels @@ -338,7 +345,6 @@ def pp_forward( def get_vit_flash_self_attention_forward(): - from transformers.models.vit.modeling_vit import ViTSelfAttention from colossalai.kernel.cuda_native import ColoAttention @@ -348,22 +354,24 @@ def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_si x = x.view(new_x_shape) return x - def forward(self: ViTSelfAttention, - hidden_states: torch.Tensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + def forward( + self: ViTSelfAttention, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: mixed_query_layer = self.query(hidden_states) key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size) - value_layer = transpose_for_scores(self.value(hidden_states), self.num_attention_heads, - self.attention_head_size) + value_layer = transpose_for_scores( + self.value(hidden_states), self.num_attention_heads, self.attention_head_size + ) query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size) scale = 1.0 / math.sqrt(self.attention_head_size) - attention = ColoAttention(embed_dim=self.all_head_size, - num_heads=self.num_attention_heads, - dropout=self.dropout.p, - scale=scale) + attention = ColoAttention( + embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale + ) context_layer = attention(query_layer, key_layer, value_layer) outputs = (context_layer,) @@ -374,7 +382,6 @@ def forward(self: ViTSelfAttention, def get_jit_fused_vit_output_forward(): - from transformers.models.vit.modeling_vit import ViTOutput def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index 62f8f7b4763e..ef59dbcee680 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -1,6 +1,6 @@ import logging import random -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -24,7 +24,6 @@ def get_whisper_flash_attention_forward(): - from transformers.models.whisper.modeling_whisper import WhisperAttention from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention @@ -53,8 +52,11 @@ def forward( # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as # the provided `key_value_states` to support prefix tuning - if (is_cross_attention and past_key_value is not None - and past_key_value[0].shape[1] == key_value_states.shape[1]): + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[1] == key_value_states.shape[1] + ): # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] @@ -89,8 +91,10 @@ def forward( src_len = key_states.size(1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): - raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}") + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) attn_type = None flash_attention_mask = None @@ -104,15 +108,12 @@ def forward( flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous()) attn_type = AttnMaskType.paddedcausal - attention = ColoAttention(embed_dim=self.embed_dim, - num_heads=self.num_heads, - dropout=self.dropout, - scale=self.scaling) - attn_output = attention(query_states, - key_states, - value_states, - attn_mask=flash_attention_mask, - attn_mask_type=attn_type) + attention = ColoAttention( + embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling + ) + attn_output = attention( + query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_type + ) attn_output = self.out_proj(attn_output) @@ -122,7 +123,6 @@ def forward( def get_jit_fused_whisper_encoder_layer_forward(): - from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer def forward( @@ -160,8 +160,9 @@ def forward( hidden_states = self.fc2(hidden_states) hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training) - if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any() - or torch.isnan(hidden_states).any()): + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) @@ -176,7 +177,6 @@ def forward( def get_jit_fused_whisper_decoder_layer_forward(): - from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer def forward( @@ -269,10 +269,10 @@ def forward( class WhisperPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of Llama models under pipeline setting. - ''' + """ @staticmethod def whisper_encoder_forward( @@ -315,15 +315,16 @@ def whisper_encoder_forward( return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - logger = logging.get_logger(__name__) + logging.get_logger(__name__) stage = stage_manager.stage - at_first_stage = (stage == 0) - at_last_stage = (stage == decoder_starting_stage - 1) + at_first_stage = stage == 0 + at_last_stage = stage == decoder_starting_stage - 1 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Process inputs if at the first stage of encoder. @@ -349,7 +350,8 @@ def whisper_encoder_forward( else: if hidden_states is None: raise ValueError( - "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." + ) start_idx, end_idx = stage_index[0], stage_index[1] @@ -360,13 +362,12 @@ def whisper_encoder_forward( encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = random.uniform(0, 1) - if self.training and (dropout_probability < self.layerdrop): # skip the layer + if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, output_attentions) @@ -398,12 +399,12 @@ def custom_forward(*inputs): if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput(last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) else: - return {'hidden_states': hidden_states, 'head_mask': head_mask} + return {"hidden_states": hidden_states, "head_mask": head_mask} @staticmethod def whisper_decoder_forward( @@ -483,12 +484,13 @@ def whisper_decoder_forward( """ logger = logging.get_logger(__name__) stage = stage_manager.stage - at_first_stage = (stage == decoder_starting_stage) - at_last_stage = (stage == stage_manager.num_stages - 1) + at_first_stage = stage == decoder_starting_stage + at_last_stage = stage == stage_manager.num_stages - 1 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -503,7 +505,8 @@ def whisper_decoder_forward( if attn_mask is not None: assert attn_mask.size()[0] == (len(self.layers)), ( f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {head_mask.size()[0]}.") + f" {head_mask.size()[0]}." + ) # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 @@ -529,8 +532,9 @@ def whisper_decoder_forward( else: positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, - past_key_values_length) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -543,14 +547,15 @@ def whisper_decoder_forward( use_cache = False else: - if hidden_states is None: raise ValueError( - "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.") + "hidden_states shouldn't be None for stages other than the first stage of encoder/decoder." + ) input_shape = hidden_states.size()[:-1] - attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states, - past_key_values_length) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, hidden_states, past_key_values_length + ) start_idx, end_idx = stage_index[0], stage_index[1] @@ -569,7 +574,6 @@ def whisper_decoder_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, use_cache) @@ -581,10 +585,10 @@ def custom_forward(*inputs): hidden_states, attention_mask, encoder_hidden_states, - None, # encoder attention mask + None, # encoder attention mask head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, # past_key_value + None, # past_key_value ) else: layer_outputs = decoder_layer( @@ -592,8 +596,9 @@ def custom_forward(*inputs): attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=(cross_attn_head_mask[idx] - if cross_attn_head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -617,8 +622,10 @@ def custom_forward(*inputs): next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple( - v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] - if v is not None) + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -629,9 +636,9 @@ def custom_forward(*inputs): else: return { - 'head_mask': head_mask, - 'cross_attn_head_mask': cross_attn_head_mask, - 'hidden_states': hidden_states, + "head_mask": head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "hidden_states": hidden_states, } @staticmethod @@ -678,23 +685,24 @@ def whisper_model_forward( ```""" # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future. if past_key_values: - logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.') + logger.warning_once("Non-empty past_key_values is not supported for pipeline models at the moment.") past_key_values = None if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False if use_cache: - logger.warning_once('use_cache=True is not supported for pipeline models at the moment.') + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - logger = logging.get_logger(__name__) + logging.get_logger(__name__) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict in_decoder = stage_manager.stage >= decoder_starting_stage @@ -712,14 +720,15 @@ def whisper_model_forward( stage_manager=stage_manager, hidden_states=hidden_states, stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_starting_stage=decoder_starting_stage, + ) if stage_manager.stage == decoder_starting_stage - 1: # last stage of encoder - return {'encoder_hidden_states': encoder_outputs[0]} + return {"encoder_hidden_states": encoder_outputs[0]} else: return encoder_outputs - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): encoder_outputs = BaseModelOutput( last_hidden_state=encoder_outputs[0], @@ -738,27 +747,29 @@ def whisper_model_forward( raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.") # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder, - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward( + self.decoder, + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) # Directly return outputs of overloaded Whisper forward if not at last stage. if not at_last_decoder_stage: # encoder_hidden_states should be passed to the next stage - decoder_outputs['encoder_hidden_states'] = encoder_hidden_states + decoder_outputs["encoder_hidden_states"] = encoder_hidden_states return decoder_outputs if not return_dict: @@ -830,36 +841,39 @@ def whisper_for_conditional_generation_forward( if labels is not None: if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, - self.config.decoder_start_token_id) + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) in_decoder = stage_manager.stage >= decoder_starting_stage at_last_decoder_stage = stage_manager.is_last_stage() - outputs = WhisperPipelineForwards.whisper_model_forward(self.model, - input_features, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - encoder_outputs=encoder_outputs, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + outputs = WhisperPipelineForwards.whisper_model_forward( + self.model, + input_features, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) if not in_decoder: return outputs if not at_last_decoder_stage: # encoder_hidden_states should be passed to the next stage - outputs['encoder_hidden_states'] = encoder_hidden_states + outputs["encoder_hidden_states"] = encoder_hidden_states return outputs lm_logits = self.proj_out(outputs[0]) @@ -909,8 +923,9 @@ def whisper_for_audio_classification_forward( Please refer to original code of transformers for more details. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # audio_classification only holds encoder diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 49613ffb37e0..3bea91ef94dc 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -18,6 +18,7 @@ class PolicyLocation: file_name (str): The file name of the policy under colossalai.shardformer.policies class_name (str): The class name of the policy class """ + file_name: str class_name: str @@ -27,121 +28,142 @@ class PolicyLocation: # we will allow the user to only import the policy file needed _POLICY_LIST = { # BERT - "transformers.models.bert.modeling_bert.BertModel": - PolicyLocation(file_name="bert", class_name="BertModelPolicy"), - "transformers.models.bert.modeling_bert.BertForPreTraining": - PolicyLocation(file_name="bert", class_name="BertForPreTrainingPolicy"), - "transformers.models.bert.modeling_bert.BertLMHeadModel": - PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"), - "transformers.models.bert.modeling_bert.BertForMaskedLM": - PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"), - "transformers.models.bert.modeling_bert.BertForSequenceClassification": - PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"), - "transformers.models.bert.modeling_bert.BertForTokenClassification": - PolicyLocation(file_name="bert", class_name="BertForTokenClassificationPolicy"), - "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": - PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), - "transformers.models.bert.modeling_bert.BertForMultipleChoice": - PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"), - "transformers.models.bert.modeling_bert.BertForQuestionAnswering": - PolicyLocation(file_name="bert", class_name="BertForQuestionAnsweringPolicy"), - + "transformers.models.bert.modeling_bert.BertModel": PolicyLocation(file_name="bert", class_name="BertModelPolicy"), + "transformers.models.bert.modeling_bert.BertForPreTraining": PolicyLocation( + file_name="bert", class_name="BertForPreTrainingPolicy" + ), + "transformers.models.bert.modeling_bert.BertLMHeadModel": PolicyLocation( + file_name="bert", class_name="BertLMHeadModelPolicy" + ), + "transformers.models.bert.modeling_bert.BertForMaskedLM": PolicyLocation( + file_name="bert", class_name="BertForMaskedLMPolicy" + ), + "transformers.models.bert.modeling_bert.BertForSequenceClassification": PolicyLocation( + file_name="bert", class_name="BertForSequenceClassificationPolicy" + ), + "transformers.models.bert.modeling_bert.BertForTokenClassification": PolicyLocation( + file_name="bert", class_name="BertForTokenClassificationPolicy" + ), + "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": PolicyLocation( + file_name="bert", class_name="BertForNextSentencePredictionPolicy" + ), + "transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation( + file_name="bert", class_name="BertForMultipleChoicePolicy" + ), + "transformers.models.bert.modeling_bert.BertForQuestionAnswering": PolicyLocation( + file_name="bert", class_name="BertForQuestionAnsweringPolicy" + ), # LLaMA - "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="llama", class_name="LlamaModelPolicy"), - "transformers.models.llama.modeling_llama.LlamaForCausalLM": - PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"), - "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": - PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"), - + "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation( + file_name="llama", class_name="LlamaModelPolicy" + ), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation( + file_name="llama", class_name="LlamaForCausalLMPolicy" + ), + "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": PolicyLocation( + file_name="llama", class_name="LlamaForSequenceClassificationPolicy" + ), # T5 - "transformers.models.t5.modeling_t5.T5Model": - PolicyLocation(file_name="t5", class_name="T5ModelPolicy"), - "transformers.models.t5.modeling_t5.T5ForConditionalGeneration": - PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"), - "transformers.models.t5.modeling_t5.T5EncoderModel": - PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), - + "transformers.models.t5.modeling_t5.T5Model": PolicyLocation(file_name="t5", class_name="T5ModelPolicy"), + "transformers.models.t5.modeling_t5.T5ForConditionalGeneration": PolicyLocation( + file_name="t5", class_name="T5ForConditionalGenerationPolicy" + ), + "transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), # GPT2 - "transformers.models.gpt2.modeling_gpt2.GPT2Model": - PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), - "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": - PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"), - "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": - PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"), - "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": - PolicyLocation(file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy"), - "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": - PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"), - "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": - PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"), - + "transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation( + file_name="gpt2", class_name="GPT2LMHeadModelPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": PolicyLocation( + file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2ForQuestionAnswering": PolicyLocation( + file_name="gpt2", class_name="GPT2ForQuestionAnsweringPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": PolicyLocation( + file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy" + ), + "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation( + file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy" + ), # ViT - "transformers.models.vit.modeling_vit.ViTModel": - PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), - "transformers.models.vit.modeling_vit.ViTForImageClassification": - PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"), - "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": - PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"), - + "transformers.models.vit.modeling_vit.ViTModel": PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), + "transformers.models.vit.modeling_vit.ViTForImageClassification": PolicyLocation( + file_name="vit", class_name="ViTForImageClassificationPolicy" + ), + "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": PolicyLocation( + file_name="vit", class_name="ViTForMaskedImageModelingPolicy" + ), # OPT - "transformers.models.opt.modeling_opt.OPTModel": - PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), - "transformers.models.opt.modeling_opt.OPTForCausalLM": - PolicyLocation(file_name="opt", class_name="OPTForCausalLMPolicy"), - "transformers.models.opt.modeling_opt.OPTForSequenceClassification": - PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"), - "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": - PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"), - + "transformers.models.opt.modeling_opt.OPTModel": PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), + "transformers.models.opt.modeling_opt.OPTForCausalLM": PolicyLocation( + file_name="opt", class_name="OPTForCausalLMPolicy" + ), + "transformers.models.opt.modeling_opt.OPTForSequenceClassification": PolicyLocation( + file_name="opt", class_name="OPTForSequenceClassificationPolicy" + ), + "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": PolicyLocation( + file_name="opt", class_name="OPTForQuestionAnsweringPolicy" + ), # Bloom - "transformers.models.bloom.modeling_bloom.BloomModel": - PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"), - "transformers.models.bloom.modeling_bloom.BloomForCausalLM": - PolicyLocation(file_name="bloom", class_name="BloomForCausalLMPolicy"), - "transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": - PolicyLocation(file_name="bloom", class_name="BloomForSequenceClassificationPolicy"), - "transformers.models.bloom.modeling_bloom.BloomForTokenClassification": - PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"), - "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": - PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"), - + "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation( + file_name="bloom", class_name="BloomModelPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation( + file_name="bloom", class_name="BloomForCausalLMPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": PolicyLocation( + file_name="bloom", class_name="BloomForSequenceClassificationPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForTokenClassification": PolicyLocation( + file_name="bloom", class_name="BloomForTokenClassificationPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": PolicyLocation( + file_name="bloom", class_name="BloomForQuestionAnsweringPolicy" + ), # Whisper - "transformers.models.whisper.modeling_whisper.WhisperModel": - PolicyLocation(file_name="whisper", class_name="WhisperModelPolicy"), - "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": - PolicyLocation(file_name="whisper", class_name="WhisperForConditionalGenerationPolicy"), - "transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": - PolicyLocation(file_name="whisper", class_name="WhisperForAudioClassificationPolicy"), - + "transformers.models.whisper.modeling_whisper.WhisperModel": PolicyLocation( + file_name="whisper", class_name="WhisperModelPolicy" + ), + "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration": PolicyLocation( + file_name="whisper", class_name="WhisperForConditionalGenerationPolicy" + ), + "transformers.models.whisper.modeling_whisper.WhisperForAudioClassification": PolicyLocation( + file_name="whisper", class_name="WhisperForAudioClassificationPolicy" + ), # Sam - "transformers.models.sam.modeling_sam.SamModel": - PolicyLocation(file_name="sam", class_name="SamModelPolicy"), - + "transformers.models.sam.modeling_sam.SamModel": PolicyLocation(file_name="sam", class_name="SamModelPolicy"), # Blip2 - "transformers.models.blip_2.modeling_blip_2.Blip2Model": - PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"), - "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": - PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"), - + "transformers.models.blip_2.modeling_blip_2.Blip2Model": PolicyLocation( + file_name="blip2", class_name="Blip2ModelPolicy" + ), + "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration": PolicyLocation( + file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy" + ), # ChatGLM - "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": - PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"), - "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": - PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"), + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation( + file_name="chatglm2", class_name="ChatGLMModelPolicy" + ), + "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( + file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" + ), } _INFER_POLICY_LIST = { # LlaMa - "transformers.models.llama.modeling_llama.LlamaModel": - PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), - "transformers.models.llama.modeling_llama.LlamaForCausalLM": - PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), + "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation( + file_name="llama", class_name="LlamaModelInferPolicy" + ), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation( + file_name="llama", class_name="LlamaModelInferPolicy" + ), # Bloom - "transformers.models.bloom.modeling_bloom.BloomModel": - PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"), - "transformers.models.bloom.modeling_bloom.BloomForCausalLM": - PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"), + "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation( + file_name="bloom", class_name="BloomModelInferPolicy" + ), + "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation( + file_name="bloom", class_name="BloomModelInferPolicy" + ), } @@ -163,9 +185,9 @@ def _fullname(obj): """ klass = obj.__class__ module = klass.__module__ - if module == 'builtins': - return klass.__qualname__ # avoid outputs like 'builtins.str' - return module + '.' + klass.__qualname__ + if module == "builtins": + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + "." + klass.__qualname__ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy: diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 961c6a5259fe..e7f199129a00 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -106,14 +106,12 @@ def config_sanity_check(self): This method is made abstractmethod with no default implementation because we want to the policy writer to take note of the feature supported by his/her model and policy. """ - pass @abstractmethod def preprocess(self) -> nn.Module: r""" Perform some preprocessing of the model, like reshaping the embedding layer. """ - pass @abstractmethod def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -122,7 +120,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: and the value is the ModulePolicyDescription object. The ModulePolicyDescription object describes how the module will be transformed. """ - pass @abstractmethod def postprocess(self) -> nn.Module: @@ -130,13 +127,13 @@ def postprocess(self) -> nn.Module: Perform some postprocessing of the model, like binding the weight of embedding layer with the classifier layer """ - pass def append_or_create_submodule_replacement( - self, description: Union[SubModuleReplacementDescription, - List[SubModuleReplacementDescription]], policy: Dict[Union[str, nn.Module], - ModulePolicyDescription], - target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + self, + description: Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]], + policy: Dict[Union[str, nn.Module], ModulePolicyDescription], + target_key: Union[str, nn.Module], + ) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: r""" Append or create a new submodule replacement description to the policy for the given key. @@ -161,8 +158,11 @@ def append_or_create_submodule_replacement( return policy def append_or_create_method_replacement( - self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription], - target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + self, + description: Dict[str, Callable], + policy: Dict[Union[str, nn.Module], ModulePolicyDescription], + target_key: Union[str, nn.Module], + ) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: r""" Append or create a new method replacement description to the policy for the given key. @@ -199,9 +199,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: @staticmethod def distribute_layers(num_layers: int, num_stages: int) -> List[int]: - """Divide layers into stages - - """ + """Divide layers into stages""" quotient = num_layers // num_stages remainder = num_layers % num_stages diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index a141b7bd8fdf..14146de158ae 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -7,7 +7,6 @@ import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.bert import ( BertPipelineForwards, bert_sequence_parallel_forward_fn, @@ -19,14 +18,20 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMdHeadModelPolicy', 'BertForMaskedLMPolicy', - 'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy', - 'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy' + "BertPolicy", + "BertModelPolicy", + "BertForPreTrainingPolicy", + "BertLMdHeadModelPolicy", + "BertForMaskedLMPolicy", + "BertForNextSentencePredictionPolicy", + "BertForSequenceClassificationPolicy", + "BertForTokenClassificationPolicy", + "BertForMultipleChoicePolicy", + "BertForQuestionAnsweringPolicy", ] class BertPolicy(Policy): - def config_sanity_check(self): pass @@ -58,136 +63,140 @@ def module_policy(self): use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: - policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ - "attention.self.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "crossattention.self.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attention.self.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "crossattention.self.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.self.query", - target_module=col_nn.Linear1D_Col, - kwargs={ - "seq_parallel": use_sequence_parallel, - "overlap": overlap - }, - ), - SubModuleReplacementDescription( - suffix="attention.self.key", - target_module=col_nn.Linear1D_Col, - kwargs={ - "seq_parallel": use_sequence_parallel, - "overlap": overlap - }, - ), - SubModuleReplacementDescription( - suffix="attention.self.value", - target_module=col_nn.Linear1D_Col, - kwargs={ - "seq_parallel": use_sequence_parallel, - "overlap": overlap - }, - ), - SubModuleReplacementDescription( - suffix="attention.self.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=col_nn.Linear1D_Col, - kwargs={ - "seq_parallel": use_sequence_parallel, - "overlap": overlap - }, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=col_nn.Linear1D_Row, - kwargs={"seq_parallel": use_sequence_parallel}, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) - - policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, - ) - ]) + policy[BertLayer] = ModulePolicyDescription( + attribute_replacement={ + "attention.self.all_head_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "crossattention.self.all_head_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "attention.self.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "crossattention.self.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) + + policy[BertEmbeddings] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ] + ) if use_sequence_parallel: self.append_or_create_method_replacement( - description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)}, + description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, - target_key=BertModel) + target_key=BertModel, + ) # optimization configuration if self.shard_config.enable_fused_normalization: # Handle bert layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="attention.output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=BertLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=BertLayer, + ) # handle embedding layer self.append_or_create_submodule_replacement( - description=[SubModuleReplacementDescription( - suffix="LayerNorm", - target_module=col_nn.FusedLayerNorm, - )], + description=[ + SubModuleReplacementDescription( + suffix="LayerNorm", + target_module=col_nn.FusedLayerNorm, + ) + ], policy=policy, - target_key=BertEmbeddings) + target_key=BertEmbeddings, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_bert_flash_attention_forward(), - }, - policy=policy, - target_key=BertSelfAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_bert_flash_attention_forward(), + }, + policy=policy, + target_key=BertSelfAttention, + ) # use jit operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_bert_self_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=BertSelfOutput) - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_bert_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=BertOutput) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bert_self_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BertSelfOutput, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bert_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BertOutput, + ) return policy @@ -196,31 +205,37 @@ def add_lm_head_policy(self, base_policy): # optimize for tensor parallelism if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - policy=base_policy, - target_key=BertLMPredictionHead) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ), + policy=base_policy, + target_key=BertLMPredictionHead, + ) # optimize with fused normalization if self.shard_config.enable_fused_normalization: # Handle bert lm prediction head - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="transform.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - policy=base_policy, - target_key=BertLMPredictionHead) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="transform.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + policy=base_policy, + target_key=BertLMPredictionHead, + ) return base_policy def add_lm_prediction_policy(self, base_policy): from transformers.models.bert.modeling_bert import BertLMPredictionHead + method_replacement = { - '_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict, - '_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict, + "_save_to_state_dict": col_nn.ParallelModule._save_to_state_dict, + "_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict, } - self.append_or_create_method_replacement(description=method_replacement, - policy=base_policy, - target_key=BertLMPredictionHead) + self.append_or_create_method_replacement( + description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead + ) return base_policy def postprocess(self): @@ -228,7 +243,7 @@ def postprocess(self): def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager if self.model.__class__.__name__ == "BertModel": @@ -239,15 +254,13 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config) + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) return @@ -255,7 +268,7 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'BertModel': + if self.model.__class__.__name__ == "BertModel": module = self.model else: module = self.model.bert @@ -275,17 +288,17 @@ def get_held_layers(self) -> List[Module]: # BertModel class BertModelPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): policy = super().module_policy() from transformers.models.bert.modeling_bert import BertModel + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertModel, - new_forward=BertPipelineForwards.bert_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -300,7 +313,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # BertForPreTraining class BertForPreTrainingPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -309,10 +321,13 @@ def module_policy(self): policy = self.add_lm_head_policy(policy) policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForPreTraining + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForPreTraining, - new_forward=BertPipelineForwards.bert_for_pretraining_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForPreTraining, + new_forward=BertPipelineForwards.bert_for_pretraining_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[Module]: @@ -329,16 +344,17 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight): # tie weights - return [{ - 0: model.bert.embeddings.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight - }] + return [ + { + 0: model.bert.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight, + } + ] return [] # BertLMHeadModel class BertLMHeadModelPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -347,10 +363,11 @@ def module_policy(self): policy = self.add_lm_head_policy(policy) policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertLMHeadModel + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertLMHeadModel, - new_forward=BertPipelineForwards.bert_lm_head_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -368,16 +385,17 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): # tie weights - return [{ - 0: bert_model.embeddings.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight - }] + return [ + { + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight, + } + ] return [] # BertForMaskedLM class BertForMaskedLMPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -386,10 +404,11 @@ def module_policy(self): policy = self.add_lm_head_policy(policy) policy = self.add_lm_prediction_policy(policy) from transformers.models.bert.modeling_bert import BertForMaskedLM + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForMaskedLM, - new_forward=BertPipelineForwards.bert_for_masked_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -407,16 +426,17 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight): # tie weights - return [{ - 0: bert_model.embeddings.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight - }] + return [ + { + 0: bert_model.embeddings.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight, + } + ] return [] # BertForSequenceClassification class BertForSequenceClassificationPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -427,19 +447,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - BertForSequenceClassification: - ModulePolicyDescription(sub_module_replacement=[ + BertForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", target_module=col_nn.DropoutForParallelInput, ) - ]) + ] + ) } policy.update(addon_module) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForSequenceClassification, - new_forward=BertPipelineForwards.bert_for_sequence_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForSequenceClassification, + new_forward=BertPipelineForwards.bert_for_sequence_classification_forward, + policy=policy, + ) return policy @@ -461,7 +484,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # BertForTokenClassification class BertForTokenClassificationPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -472,19 +494,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - BertForTokenClassification: - ModulePolicyDescription(sub_module_replacement=[ + BertForTokenClassification: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", target_module=col_nn.DropoutForParallelInput, ) - ]) + ] + ) } policy.update(addon_module) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForTokenClassification, - new_forward=BertPipelineForwards.bert_for_token_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForTokenClassification, + new_forward=BertPipelineForwards.bert_for_token_classification_forward, + policy=policy, + ) return policy @@ -506,17 +531,19 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # BertForNextSentencePrediction class BertForNextSentencePredictionPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): policy = super().module_policy() from transformers.models.bert.modeling_bert import BertForNextSentencePrediction + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForNextSentencePrediction, - new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForNextSentencePrediction, + new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward, + policy=policy, + ) return policy @@ -537,7 +564,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # BertForMultipleChoice class BertForMultipleChoicePolicy(BertPolicy): - def __init__(self) -> None: super().__init__() @@ -548,19 +574,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - BertForMultipleChoice: - ModulePolicyDescription(sub_module_replacement=[ + BertForMultipleChoice: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", target_module=col_nn.DropoutForParallelInput, ) - ]) + ] + ) } policy.update(addon_module) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForMultipleChoice, - new_forward=BertPipelineForwards.bert_for_multiple_choice_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForMultipleChoice, + new_forward=BertPipelineForwards.bert_for_multiple_choice_forward, + policy=policy, + ) return policy @@ -581,17 +610,19 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class BertForQuestionAnsweringPolicy(BertPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): from transformers.models.bert.modeling_bert import BertForQuestionAnswering + policy = super().module_policy() if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BertForQuestionAnswering, - new_forward=BertPipelineForwards.bert_for_question_answering_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BertForQuestionAnswering, + new_forward=BertPipelineForwards.bert_for_question_answering_forward, + policy=policy, + ) return policy diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py index 2e5388ab0490..997643d1a911 100644 --- a/colossalai/shardformer/policies/blip2.py +++ b/colossalai/shardformer/policies/blip2.py @@ -1,8 +1,5 @@ -import torch.nn as nn - import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.blip2 import ( forward_fn, get_blip2_flash_attention_forward, @@ -12,11 +9,10 @@ from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['BlipPolicy', 'BlipModelPolicy'] +__all__ = ["BlipPolicy", "BlipModelPolicy"] class BlipPolicy(Policy): - def config_sanity_check(self): pass @@ -48,263 +44,293 @@ def module_policy(self): policy = {} if self.shard_config.enable_tensor_parallelism: - policy[Blip2EncoderLayer] = ModulePolicyDescription(attribute_replacement={ - "self_attn.num_heads": - self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attn.embed_dim": - self.model.config.vision_config.hidden_size // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="self_attn.qkv", - target_module=col_nn.FusedLinear1D_Col, - kwargs={ - "n_fused": 3, - }), - SubModuleReplacementDescription( - suffix="self_attn.projection", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.fc1", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.fc2", - target_module=col_nn.Linear1D_Row, - ), - ]) + policy[Blip2EncoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.num_heads": self.model.config.vision_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "self_attn.embed_dim": self.model.config.vision_config.hidden_size + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="self_attn.qkv", + target_module=col_nn.FusedLinear1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="self_attn.projection", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) - policy[Blip2QFormerModel] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[Blip2QFormerModel] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) - policy[Blip2QFormerLayer] = ModulePolicyDescription(attribute_replacement={ - "attention.attention.num_attention_heads": - self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size, - "attention.attention.all_head_size": - self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size, - "crossattention.attention.num_attention_heads": - self.model.config.qformer_config.num_attention_heads // self.shard_config.tensor_parallel_size, - "crossattention.attention.all_head_size": - self.model.config.qformer_config.hidden_size // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="crossattention.attention.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="crossattention.attention.key", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="crossattention.attention.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="crossattention.attention.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="crossattention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="crossattention.output.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate_query.dense", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output_query.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output_query.dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) + policy[Blip2QFormerLayer] = ModulePolicyDescription( + attribute_replacement={ + "attention.attention.num_attention_heads": self.model.config.qformer_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": self.model.config.qformer_config.hidden_size + // self.shard_config.tensor_parallel_size, + "crossattention.attention.num_attention_heads": self.model.config.qformer_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "crossattention.attention.all_head_size": self.model.config.qformer_config.hidden_size + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="crossattention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate_query.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output_query.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output_query.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ], + ) - policy[OPTDecoderLayer] = ModulePolicyDescription(attribute_replacement={ - "self_attn.embed_dim": - self.model.config.text_config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": - self.model.config.text_config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.out_proj", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="fc1", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=col_nn.Linear1D_Row, - ) - ]) + policy[OPTDecoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.embed_dim": self.model.config.text_config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.text_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) - policy[OPTForCausalLM] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="model.decoder.embed_tokens", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}, - ), - ]) + policy[OPTForCausalLM] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="model.decoder.embed_tokens", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}, + ), + ] + ) policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()}) # optimization configuration if self.shard_config.enable_fused_normalization: # Handle Blip2EncoderLayer layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm2", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=Blip2EncoderLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=Blip2EncoderLayer, + ) # handle Blip2VisionModel layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="post_layernorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=Blip2VisionModel) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="post_layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=Blip2VisionModel, + ) # handle Blip2VisionModel layer self.append_or_create_submodule_replacement( - description=[SubModuleReplacementDescription( - suffix="layernorm", - target_module=col_nn.FusedLayerNorm, - )], + description=[ + SubModuleReplacementDescription( + suffix="layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ], policy=policy, - target_key=Blip2QFormerModel) + target_key=Blip2QFormerModel, + ) # handle Blip2QFormerLayer layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="attention.output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="crossattention.output.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="output_query.LayerNorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=Blip2QFormerLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="crossattention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="output_query.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=Blip2QFormerLayer, + ) # handle OPTForCausalLM layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="model.decoder.final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=OPTForCausalLM) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="model.decoder.final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=OPTForCausalLM, + ) # handle OPTDecoderLayer layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=OPTDecoderLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=OPTDecoderLayer, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_blip2_flash_attention_forward(), - }, - policy=policy, - target_key=Blip2Attention) + self.append_or_create_method_replacement( + description={ + "forward": get_blip2_flash_attention_forward(), + }, + policy=policy, + target_key=Blip2Attention, + ) # use jit operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_blip2_QFormer_self_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=Blip2QFormerSelfOutput) - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_blip2_QFormer_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=Blip2QFormerOutput) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_blip2_QFormer_self_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=Blip2QFormerSelfOutput, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_blip2_QFormer_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=Blip2QFormerOutput, + ) return policy @@ -314,13 +340,11 @@ def postprocess(self): # Blip2Model class Blip2ModelPolicy(BlipPolicy): - def __init__(self) -> None: super().__init__() # Blip2ForConditionalGeneration class Blip2ForConditionalGenerationPolicy(BlipPolicy): - def __init__(self) -> None: super().__init__() diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 7c418d02bcb6..13b9dd31345d 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List import torch.nn as nn from torch import Tensor @@ -7,7 +7,6 @@ import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.bloom import ( BloomPipelineForwards, build_bloom_alibi_tensor_fn, @@ -22,7 +21,6 @@ class BloomPolicy(Policy): - def config_sanity_check(self): pass @@ -47,39 +45,41 @@ def module_policy(self): use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: - policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ - "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - kwargs={ - 'seq_parallel': use_sequence_parallel, - 'overlap': overlap - }), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - kwargs={'seq_parallel': use_sequence_parallel}), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=col_nn.Linear1D_Col, - kwargs={ - 'seq_parallel': use_sequence_parallel, - 'overlap': overlap - }), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=col_nn.Linear1D_Row, - kwargs={'seq_parallel': use_sequence_parallel}), - ]) + policy[BloomBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel}, + ), + ], + ) policy[BloomModel] = ModulePolicyDescription( attribute_replacement={ @@ -93,72 +93,86 @@ def module_policy(self): suffix="word_embeddings", target_module=col_nn.VocabParallelEmbedding1D, ) - ]) + ], + ) # optimization configuration if self.shard_config.enable_fused_normalization: # handle bloom model - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="ln_f", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="word_embeddings_layernorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=BloomModel) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="word_embeddings_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=BloomModel, + ) # handle bloom block - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=BloomBlock) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=BloomBlock, + ) if use_sequence_parallel: self.append_or_create_method_replacement( - description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)}, + description={"forward": get_bloom_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, - target_key=BloomModel) + target_key=BloomModel, + ) if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_bloom_flash_attention_forward(), - 'dropout_add': get_dropout_add_func(), - }, - policy=policy, - target_key=BloomAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_bloom_flash_attention_forward(), + "dropout_add": get_dropout_add_func(), + }, + policy=policy, + target_key=BloomAttention, + ) # enable jit fused operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_bloom_attention_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=BloomAttention) - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_bloom_mlp_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=BloomMLP) - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_bloom_gelu_forward(), - 'bloom_gelu_forward': get_jit_fused_gelu_forward_func(), - }, - policy=policy, - target_key=BloomGelu) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bloom_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BloomAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bloom_mlp_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=BloomMLP, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_bloom_gelu_forward(), + "bloom_gelu_forward": get_jit_fused_gelu_forward_func(), + }, + policy=policy, + target_key=BloomGelu, + ) return policy @@ -167,7 +181,7 @@ def postprocess(self): def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager if self.model.__class__.__name__ == "BloomModel": @@ -178,22 +192,20 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config) + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) return def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'BloomModel': + if self.model.__class__.__name__ == "BloomModel": module = self.model else: module = self.model.transformer @@ -213,17 +225,17 @@ def get_held_layers(self) -> List[Module]: class BloomModelPolicy(BloomPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): policy = super().module_policy() from transformers.models.bloom.modeling_bloom import BloomModel + if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BloomModel, - new_forward=BloomPipelineForwards.bloom_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BloomModel, new_forward=BloomPipelineForwards.bloom_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -234,26 +246,29 @@ def get_held_layers(self) -> List[Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''no shared params in bloom model''' + """no shared params in bloom model""" return [] class BloomForCausalLMPolicy(BloomPolicy): - def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForCausalLM + policy = super().module_policy() # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=BloomForCausalLM) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=BloomForCausalLM, + ) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BloomForCausalLM, - new_forward=BloomPipelineForwards.bloom_for_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BloomForCausalLM, new_forward=BloomPipelineForwards.bloom_for_causal_lm_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -269,29 +284,36 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: if id(bloom_model.transformer.word_embeddings.weight) == id(bloom_model.lm_head.weight): # tie weights - return [{ - 0: bloom_model.transformer.word_embeddings.weight, - self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight - }] + return [ + { + 0: bloom_model.transformer.word_embeddings.weight, + self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight, + } + ] return [] class BloomForSequenceClassificationPolicy(BloomPolicy): - def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification + policy = super().module_policy() # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=BloomForSequenceClassification) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=BloomForSequenceClassification, + ) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BloomForSequenceClassification, - new_forward=BloomPipelineForwards.bloom_for_sequence_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BloomForSequenceClassification, + new_forward=BloomPipelineForwards.bloom_for_sequence_classification_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[Module]: @@ -308,28 +330,32 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class BloomForTokenClassificationPolicy(BloomPolicy): - def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForTokenClassification + policy = super().module_policy() # handle tensor parallelism if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="classifier", - target_module=col_nn.Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - ], - policy=policy, - target_key=BloomForTokenClassification) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=BloomForTokenClassification, + ) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BloomForTokenClassification, - new_forward=BloomPipelineForwards.bloom_for_token_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BloomForTokenClassification, + new_forward=BloomPipelineForwards.bloom_for_token_classification_forward, + policy=policy, + ) return policy @@ -351,11 +377,14 @@ class BloomForQuestionAnsweringPolicy(BloomPolicy): # No head sharding as the output features is only 2 def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForQuestionAnswering + policy = super().module_policy() if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=BloomForQuestionAnswering, - new_forward=BloomPipelineForwards.bloom_for_question_answering_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=BloomForQuestionAnswering, + new_forward=BloomPipelineForwards.bloom_for_question_answering_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[Module]: diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 44898847056a..3c27c848e738 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -1,19 +1,12 @@ from functools import partial -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Union import torch.nn as nn from torch import Tensor -from transformers.modeling_outputs import BaseModelOutputWithPast import colossalai.shardformer.layer as col_nn -from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards -from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( - ChatGLMForConditionalGeneration, - ChatGLMModel, - GLMBlock, -) +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel from ..modeling.chatglm2 import ( get_chatglm_sequence_parallel_forward_fn, @@ -23,11 +16,10 @@ from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy'] +__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"] class ChatGLMPolicy(Policy): - def config_sanity_check(self): pass @@ -44,12 +36,11 @@ def preprocess(self): if self.pipeline_stage_manager is not None: # the batch_size_dim is bounded to Model bsz_dim = 1 - setattr(self.model, 'batch_size_dim', bsz_dim) + setattr(self.model, "batch_size_dim", bsz_dim) return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock policy = {} @@ -57,111 +48,129 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: - policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={}, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embedding.word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) - ]) + policy[ChatGLMModel] = ModulePolicyDescription( + attribute_replacement={}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embedding.word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ], + ) policy[GLMBlock] = ModulePolicyDescription( attribute_replacement={ - "self_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.projection_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads) // - self.shard_config.tensor_parallel_size, - "self_attention.qkv_hidden_size": - (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) // - self.shard_config.tensor_parallel_size, - "self_attention.core_attention.num_attention_heads_per_partition": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "self_attention.core_attention.hidden_size_per_partition": - self.model.config.kv_channels * self.model.config.num_attention_heads // - self.shard_config.tensor_parallel_size, + "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "self_attention.projection_size": ( + self.model.config.kv_channels * self.model.config.num_attention_heads + ) + // self.shard_config.tensor_parallel_size, + "self_attention.qkv_hidden_size": ( + self.model.config.kv_channels * self.model.config.num_attention_heads * 3 + ) + // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels + * self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, }, param_replacement=[], sub_module_replacement=[ - SubModuleReplacementDescription(suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - kwargs={ - 'seq_parallel': use_sequence_parallel, - 'seq_parallel_dim': 0, - 'overlap': overlap - }), - SubModuleReplacementDescription(suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - kwargs={ - 'seq_parallel': use_sequence_parallel, - 'seq_parallel_dim': 0 - }), + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + kwargs={"seq_parallel": use_sequence_parallel, "seq_parallel_dim": 0}, + ), SubModuleReplacementDescription( suffix="self_attention.core_attention.attention_dropout", target_module=col_nn.DropoutForParallelInput, ), - ]) + ], + ) # optimization configuration if self.shard_config.enable_fused_normalization: if not self.model.config.rmsnorm: - - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm), - SubModuleReplacementDescription(suffix="post_attention_layernorm", - target_module=col_nn.FusedLayerNorm) - ], - policy=policy, - target_key=GLMBlock) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm + ), + ], + policy=policy, + target_key=GLMBlock, + ) if self.model.config.post_layer_norm: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="encoder.final_layernorm", - target_module=col_nn.FusedLayerNorm) - ], - policy=policy, - target_key=ChatGLMModel) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm + ) + ], + policy=policy, + target_key=ChatGLMModel, + ) else: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm), - SubModuleReplacementDescription(suffix="post_attention_layernorm", - target_module=col_nn.FusedRMSNorm) - ], - policy=policy, - target_key=GLMBlock) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm + ), + ], + policy=policy, + target_key=GLMBlock, + ) if self.model.config.post_layer_norm: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="encoder.final_layernorm", - target_module=col_nn.FusedRMSNorm) - ], - policy=policy, - target_key=ChatGLMModel) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm + ) + ], + policy=policy, + target_key=ChatGLMModel, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_flash_core_attention_forward(), - }, - policy=policy, - target_key=CoreAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_flash_core_attention_forward(), + }, + policy=policy, + target_key=CoreAttention, + ) # use sequence parallel if use_sequence_parallel: self.append_or_create_method_replacement( - description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, + description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, - target_key=ChatGLMModel) + target_key=ChatGLMModel, + ) # use jit fused operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_glm_block_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=GLMBlock) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_glm_block_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=GLMBlock, + ) return policy @@ -172,7 +181,7 @@ def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'ChatGLMModel': + if self.model.__class__.__name__ == "ChatGLMModel": module = self.model else: module = self.model.transformer @@ -195,11 +204,11 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'ChatGLMModel': + if self.model.__class__.__name__ == "ChatGLMModel": module = self.model else: module = self.model.transformer @@ -207,29 +216,26 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config) + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) class ChatGLMModelPolicy(ChatGLMPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2Model + pass policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=ChatGLMModel, - new_forward=ChatGLMPipelineForwards.chatglm_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[nn.Module]: @@ -241,14 +247,15 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy): - def module_policy(self): policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=ChatGLMForConditionalGeneration, - new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=ChatGLMForConditionalGeneration, + new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[nn.Module]: diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 5093fd469af8..6f46bfc7ef9f 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -5,18 +5,20 @@ import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy', - 'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy' + "GPT2Policy", + "GPT2ModelPolicy", + "GPT2LMHeadModelPolicy", + "GPT2DoubleHeadsModelPolicy", + "GPT2ForTokenClassificationPolicy", + "GPT2ForSequenceClassificationPolicy", ] class GPT2Policy(Policy): - def config_sanity_check(self): pass @@ -40,16 +42,18 @@ def module_policy(self): use_sequence_parallel = self.shard_config.enable_sequence_parallelism overlap = self.shard_config.enable_sequence_overlap if self.shard_config.enable_tensor_parallelism: - policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="drop", - target_module=col_nn.DropoutForParallelInput, - ), - ]) + policy[GPT2Model] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="drop", + target_module=col_nn.DropoutForParallelInput, + ), + ] + ) policy[GPT2Block] = ModulePolicyDescription( attribute_replacement={ @@ -61,31 +65,27 @@ def module_policy(self): SubModuleReplacementDescription( suffix="attn.c_attn", target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, kwargs={ - "n_fused": 3, "seq_parallel": use_sequence_parallel, - "overlap": overlap }, ), - SubModuleReplacementDescription(suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel": use_sequence_parallel, - }), SubModuleReplacementDescription( suffix="mlp.c_fc", target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap}, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, kwargs={ - "n_fused": 1, "seq_parallel": use_sequence_parallel, - "overlap": overlap }, ), - SubModuleReplacementDescription(suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - kwargs={ - "seq_parallel": use_sequence_parallel, - }), SubModuleReplacementDescription( suffix="attn.attn_dropout", target_module=col_nn.DropoutForParallelInput, @@ -98,39 +98,46 @@ def module_policy(self): suffix="mlp.dropout", target_module=col_nn.DropoutForParallelInput, ), - ]) + ], + ) # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="ln_f", - target_module=col_nn.FusedLayerNorm, - ), - policy=policy, - target_key=GPT2Model) - - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="ln_1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="ln_2", + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="ln_f", target_module=col_nn.FusedLayerNorm, ), - SubModuleReplacementDescription(suffix="ln_cross_attn", - target_module=col_nn.FusedLayerNorm, - ignore_if_not_exist=True) - ], - policy=policy, - target_key=GPT2Block) + policy=policy, + target_key=GPT2Model, + ) + + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="ln_1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="ln_2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True + ), + ], + policy=policy, + target_key=GPT2Block, + ) if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_gpt2_flash_attention_forward(), - }, - policy=policy, - target_key=GPT2Attention) + self.append_or_create_method_replacement( + description={ + "forward": get_gpt2_flash_attention_forward(), + }, + policy=policy, + target_key=GPT2Attention, + ) if self.shard_config.enable_sequence_parallelism: policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} @@ -144,7 +151,7 @@ def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'GPT2Model': + if self.model.__class__.__name__ == "GPT2Model": module = self.model else: module = self.model.transformer @@ -164,11 +171,11 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'GPT2Model': + if self.model.__class__.__name__ == "GPT2Model": module = self.model else: module = self.model.transformer @@ -176,18 +183,15 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - shard_config=self.shard_config) + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) # GPT2Model class GPT2ModelPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -197,9 +201,9 @@ def module_policy(self): policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2Model, - new_forward=GPT2PipelineForwards.gpt2_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=GPT2Model, new_forward=GPT2PipelineForwards.gpt2_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[nn.Module]: @@ -212,7 +216,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # GPT2LMHeadModel class GPT2LMHeadModelPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -223,18 +226,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - GPT2LMHeadModel: - ModulePolicyDescription(sub_module_replacement=[ + GPT2LMHeadModel: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) } module_policy.update(addon_module) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2LMHeadModel, - new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy, + ) return module_policy def get_held_layers(self) -> List[nn.Module]: @@ -244,7 +251,7 @@ def get_held_layers(self) -> List[nn.Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''The weights of wte and lm_head are shared.''' + """The weights of wte and lm_head are shared.""" module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None: @@ -256,7 +263,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # GPT2DoubleHeadsModel class GPT2DoubleHeadsModelPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -267,18 +273,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - GPT2DoubleHeadsModel: - ModulePolicyDescription(sub_module_replacement=[ + GPT2DoubleHeadsModel: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) } module_policy.update(addon_module) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2DoubleHeadsModel, - new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2DoubleHeadsModel, + new_forward=GPT2PipelineForwards.gpt2_double_heads_model_forward, + policy=module_policy, + ) return module_policy @@ -295,7 +305,7 @@ def get_held_layers(self) -> List[nn.Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''The weights of wte and lm_head are shared.''' + """The weights of wte and lm_head are shared.""" module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None: @@ -307,7 +317,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # GPT2ForQuestionAnswering class GPT2ForQuestionAnsweringPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -317,9 +326,11 @@ def module_policy(self): module_policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2ForQuestionAnswering, - new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2ForQuestionAnswering, + new_forward=GPT2PipelineForwards.gpt2_for_question_answering_forward, + policy=module_policy, + ) return module_policy @@ -330,13 +341,12 @@ def get_held_layers(self) -> List[nn.Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''No shared_params in gpt2 for QA.''' + """No shared_params in gpt2 for QA.""" return [] # GPT2ForTokenClassification class GPT2ForTokenClassificationPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -347,17 +357,20 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - GPT2ForTokenClassification: - ModulePolicyDescription(sub_module_replacement=[ + GPT2ForTokenClassification: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput) - ]) + ] + ) } module_policy.update(addon_module) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2ForTokenClassification, - new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2ForTokenClassification, + new_forward=GPT2PipelineForwards.gpt2_for_token_classification_forward, + policy=module_policy, + ) return module_policy def get_held_layers(self) -> List[nn.Module]: @@ -374,7 +387,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # GPT2ForSequenceClassification class GPT2ForSequenceClassificationPolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -384,9 +396,11 @@ def module_policy(self): module_policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2ForSequenceClassification, - new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2ForSequenceClassification, + new_forward=GPT2PipelineForwards.gpt2_for_sequence_classification_forward, + policy=module_policy, + ) return module_policy def get_held_layers(self) -> List[nn.Module]: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index cc131e8168fc..099995acb440 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -11,11 +11,10 @@ from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] +__all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] class LlamaPolicy(Policy): - def config_sanity_check(self): pass @@ -40,15 +39,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: self.shard_config.enable_sequence_parallelism = False warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") - if self.shard_config.enable_tensor_parallelism: decoder_attribute_replacement = { "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, } if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["self_attn.num_key_value_heads"] = \ + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size + ) policy[LlamaDecoderLayer] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, @@ -80,45 +79,53 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: SubModuleReplacementDescription( suffix="mlp.down_proj", target_module=Linear1D_Row, - ) + ), ], ) - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ), - policy=policy, - target_key=LlamaModel) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=LlamaModel, + ) # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="input_layernorm", + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + ], + policy=policy, + target_key=LlamaDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", target_module=FusedRMSNorm, ), - SubModuleReplacementDescription( - suffix="post_attention_layernorm", - target_module=FusedRMSNorm, - ) - ], - policy=policy, - target_key=LlamaDecoderLayer) - - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=LlamaModel) + policy=policy, + target_key=LlamaModel, + ) if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_llama_flash_attention_forward(), - }, - policy=policy, - target_key=LlamaAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_llama_flash_attention_forward(), + }, + policy=policy, + target_key=LlamaAttention, + ) return policy @@ -127,7 +134,7 @@ def postprocess(self): def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager if self.model.__class__.__name__ == "LlamaModel": @@ -137,10 +144,10 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) return @@ -148,7 +155,7 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'LlamaModel': + if self.model.__class__.__name__ == "LlamaModel": module = self.model else: module = self.model.model @@ -167,18 +174,18 @@ def get_held_layers(self) -> List[Module]: class LlamaModelPolicy(LlamaPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): policy = super().module_policy() from transformers.models.llama.modeling_llama import LlamaModel + if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaModel, - new_forward=LlamaPipelineForwards.llama_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=LlamaModel, new_forward=LlamaPipelineForwards.llama_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[Module]: @@ -192,7 +199,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class LlamaForCausalLMPolicy(LlamaPolicy): - def module_policy(self): from transformers import LlamaForCausalLM @@ -201,19 +207,21 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { - LlamaForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) } policy.update(new_item) if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaForCausalLM, - new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + ) return policy @@ -228,18 +236,21 @@ def get_held_layers(self) -> List[Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: llama_model = self.model.model if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: - if id(llama_model.embed_tokens.weight) == id( - self.model.lm_head.weight) and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): # tie weights - return [{ - 0: llama_model.embed_tokens.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight - }] + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] return [] class LlamaForSequenceClassificationPolicy(LlamaPolicy): - def module_policy(self): from transformers import LlamaForSequenceClassification @@ -248,19 +259,23 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: # add a new item for sequence classification new_item = { - LlamaForSequenceClassification: - ModulePolicyDescription(sub_module_replacement=[ + LlamaForSequenceClassification: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) } policy.update(new_item) # to be confirmed if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaForSequenceClassification, - new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=LlamaForSequenceClassification, + new_forward=LlamaPipelineForwards.llama_for_sequence_classification_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[Module]: diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index abe491bfaace..5739d21a3903 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -13,13 +13,15 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy', - 'OPTForQuestionAnsweringPolicy' + "OPTPolicy", + "OPTModelPolicy", + "OPTForCausalLMPolicy", + "OPTForSequenceClassificationPolicy", + "OPTForQuestionAnsweringPolicy", ] class OPTPolicy(Policy): - def config_sanity_check(self): pass @@ -45,79 +47,94 @@ def module_policy(self): warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: - policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]) - policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]) - - policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]) + policy[OPTDecoder] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ] + ) + policy[OPTDecoderLayer] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ), + ] + ) + + policy[OPTAttention] = ModulePolicyDescription( + attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ], + ) # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), - policy=policy, - target_key=OPTDecoder) - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription(suffix="self_attn_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True), - SubModuleReplacementDescription(suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True) - ], - policy=policy, - target_key=OPTDecoderLayer) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True + ), + policy=policy, + target_key=OPTDecoder, + ) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True + ), + ], + policy=policy, + target_key=OPTDecoderLayer, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_opt_flash_attention_forward(), - }, - policy=policy, - target_key=OPTAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_opt_flash_attention_forward(), + }, + policy=policy, + target_key=OPTAttention, + ) # use jit fused operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_opt_decoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=OPTDecoderLayer) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_opt_decoder_layer_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=OPTDecoderLayer, + ) return policy @@ -128,7 +145,7 @@ def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None - if self.model.__class__.__name__ == 'OPTModel': + if self.model.__class__.__name__ == "OPTModel": module = self.model.decoder else: module = self.model.model.decoder @@ -149,24 +166,23 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'OPTModel': + if self.model.__class__.__name__ == "OPTModel": module = self.model.decoder else: module = self.model.model.decoder layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) class OPTModelPolicy(OPTPolicy): - def __init__(self) -> None: super().__init__() @@ -175,9 +191,9 @@ def module_policy(self): policy = super().module_policy() if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=OPTModel, - new_forward=OPTPipelineForwards.opt_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=OPTModel, new_forward=OPTPipelineForwards.opt_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[nn.Module]: @@ -189,20 +205,22 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class OPTForCausalLMPolicy(OPTPolicy): - def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - policy=policy, - target_key=OPTForCausalLM) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ), + policy=policy, + target_key=OPTForCausalLM, + ) if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=OPTForCausalLM, - new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=OPTForCausalLM, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, policy=policy + ) return policy @@ -223,7 +241,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: def postprocess(self): if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None: binding_map = { - 'model.decoder.embed_tokens': 'lm_head', + "model.decoder.embed_tokens": "lm_head", } for k, v in binding_map.items(): @@ -235,7 +253,6 @@ def postprocess(self): class OPTForSequenceClassificationPolicy(OPTPolicy): - def __init__(self) -> None: super().__init__() @@ -244,9 +261,11 @@ def module_policy(self): policy = super().module_policy() if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=OPTForSequenceClassification, - new_forward=OPTPipelineForwards.opt_for_sequence_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=OPTForSequenceClassification, + new_forward=OPTPipelineForwards.opt_for_sequence_classification_forward, + policy=policy, + ) return policy @@ -262,7 +281,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class OPTForQuestionAnsweringPolicy(OPTPolicy): - def __init__(self) -> None: super().__init__() @@ -271,9 +289,11 @@ def module_policy(self): policy = super().module_policy() if self.pipeline_stage_manager: - self.set_pipeline_forward(model_cls=OPTForQuestionAnswering, - new_forward=OPTPipelineForwards.opt_for_question_answering_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=OPTForQuestionAnswering, + new_forward=OPTPipelineForwards.opt_for_question_answering_forward, + policy=policy, + ) return policy diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py index 9753d5a737b9..58a8500e3863 100644 --- a/colossalai/shardformer/policies/sam.py +++ b/colossalai/shardformer/policies/sam.py @@ -1,16 +1,12 @@ -import torch.nn as nn - import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['SamPolicy', 'SamModelPolicy'] +__all__ = ["SamPolicy", "SamModelPolicy"] class SamPolicy(Policy): - def config_sanity_check(self): pass @@ -20,7 +16,6 @@ def preprocess(self): def module_policy(self): from transformers.models.sam.modeling_sam import ( SamAttention, - SamFeedForward, SamTwoWayAttentionBlock, SamTwoWayTransformer, SamVisionAttention, @@ -30,36 +25,37 @@ def module_policy(self): policy = {} if self.shard_config.enable_tensor_parallelism: - policy[SamVisionLayer] = ModulePolicyDescription(attribute_replacement={ - "attn.num_attention_heads": - self.model.config.vision_config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.qkv", - target_module=col_nn.FusedLinear1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.proj", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.lin1", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.lin2", - target_module=col_nn.Linear1D_Row, - ) - ]) + policy[SamVisionLayer] = ModulePolicyDescription( + attribute_replacement={ + "attn.num_attention_heads": self.model.config.vision_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.qkv", + target_module=col_nn.FusedLinear1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.lin1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.lin2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) policy[SamTwoWayAttentionBlock] = ModulePolicyDescription( attribute_replacement={ - "self_attn.num_attention_heads": - self.model.config.mask_decoder_config.num_attention_heads // - self.shard_config.tensor_parallel_size, + "self_attn.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads + // self.shard_config.tensor_parallel_size, }, sub_module_replacement=[ SubModuleReplacementDescription( @@ -118,97 +114,112 @@ def module_policy(self): suffix="cross_attn_image_to_token.out_proj", target_module=col_nn.Linear1D_Row, ), - ]) - policy[SamTwoWayTransformer] = ModulePolicyDescription(attribute_replacement={ - "final_attn_token_to_image.num_attention_heads": - self.model.config.mask_decoder_config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="final_attn_token_to_image.q_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="final_attn_token_to_image.k_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="final_attn_token_to_image.v_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="final_attn_token_to_image.out_proj", - target_module=col_nn.Linear1D_Row, - ) - ]) + ], + ) + policy[SamTwoWayTransformer] = ModulePolicyDescription( + attribute_replacement={ + "final_attn_token_to_image.num_attention_heads": self.model.config.mask_decoder_config.num_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="final_attn_token_to_image.out_proj", + target_module=col_nn.Linear1D_Row, + ), + ], + ) # add `DropoutForParallelInput` layer to replace the useage of `nn.functional.dropout` - policy[SamVisionAttention] = ModulePolicyDescription(attribute_replacement={ - "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout) - }, - method_replacement={"forward": forward_fn()}, - sub_module_replacement=[]) + policy[SamVisionAttention] = ModulePolicyDescription( + attribute_replacement={ + "dropout_layer": col_nn.DropoutForParallelInput(self.model.config.vision_config.attention_dropout) + }, + method_replacement={"forward": forward_fn()}, + sub_module_replacement=[], + ) # optimization configuration if self.shard_config.enable_fused_normalization: # Handle SamVisionLayer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm2", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=SamVisionLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=SamVisionLayer, + ) # Handle SamTwoWayAttentionBlock - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm1", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm2", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm3", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layer_norm4", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=SamTwoWayAttentionBlock) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm3", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layer_norm4", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=SamTwoWayAttentionBlock, + ) # Handle SamTwoWayTransformer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm_final_attn", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=SamTwoWayTransformer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm_final_attn", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=SamTwoWayTransformer, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_sam_flash_attention_forward(), - }, - policy=policy, - target_key=SamAttention) - self.append_or_create_method_replacement(description={ - 'forward': get_sam_vision_flash_attention_forward(), - }, - policy=policy, - target_key=SamVisionAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_sam_flash_attention_forward(), + }, + policy=policy, + target_key=SamAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_sam_vision_flash_attention_forward(), + }, + policy=policy, + target_key=SamVisionAttention, + ) return policy @@ -218,6 +229,5 @@ def postprocess(self): # SamModel class SamModelPolicy(SamPolicy): - def __init__(self) -> None: super().__init__() diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 92cbd3f72b83..74cc7337e9f1 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,6 +1,6 @@ import warnings from functools import partial -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Tuple import numpy as np from torch import Tensor, nn @@ -15,7 +15,6 @@ ) from colossalai.shardformer.policies.base_policy import ModulePolicyDescription -from .._utils import getattr_, setattr_ from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.t5 import ( T5PipelineForwards, @@ -30,7 +29,6 @@ class T5BasePolicy(Policy): - def config_sanity_check(self): pass @@ -65,151 +63,181 @@ def module_policy(self): warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: - policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]) - policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - ]) - policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]) - policy[T5Attention] = ModulePolicyDescription(attribute_replacement={ - "d_model": - self.model.config.d_model // self.shard_config.tensor_parallel_size, - "n_heads": - self.model.config.num_heads // self.shard_config.tensor_parallel_size, - "inner_dim": - self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="o", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="relative_attention_bias", - target_module=Embedding1D, - kwargs=dict(gather_output=False), - ignore_if_not_exist=True) - ]) - policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - ]) - policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wi_0 ", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="wi_1", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]) - policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wi", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="wo", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]) + policy[T5Stack] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + ] + ) + policy[T5LayerSelfAttention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5LayerCrossAttention] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ] + ) + policy[T5Attention] = ModulePolicyDescription( + attribute_replacement={ + "d_model": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "n_heads": self.model.config.num_heads // self.shard_config.tensor_parallel_size, + "inner_dim": self.model.config.num_heads + * self.model.config.d_kv + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="o", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="relative_attention_bias", + target_module=Embedding1D, + kwargs=dict(gather_output=False), + ignore_if_not_exist=True, + ), + ], + ) + policy[T5LayerFF] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5DenseGatedActDense] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0 ", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) + policy[T5DenseActDense] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ] + ) # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="layer_norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=T5LayerFF) - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="layer_norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=T5LayerFF) - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="layer_norm", target_module=FusedRMSNorm), - policy=policy, - target_key=T5LayerSelfAttention) - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="layer_norm", target_module=FusedRMSNorm), - policy=policy, - target_key=T5LayerCrossAttention) - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="final_layer_norm", target_module=FusedRMSNorm), - policy=policy, - target_key=T5Stack) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerSelfAttention, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerCrossAttention, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5Stack, + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_t5_flash_attention_forward(), - }, - policy=policy, - target_key=T5Attention) + self.append_or_create_method_replacement( + description={ + "forward": get_t5_flash_attention_forward(), + }, + policy=policy, + target_key=T5Attention, + ) # use jit operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_T5_layer_ff_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=T5LayerFF) - self.append_or_create_method_replacement(description={ - 'forward': get_T5_layer_self_attention_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=T5LayerSelfAttention) - self.append_or_create_method_replacement(description={ - 'forward': get_T5_layer_cross_attention_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=T5LayerCrossAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_T5_layer_ff_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_T5_layer_self_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerSelfAttention, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_T5_layer_cross_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerCrossAttention, + ) return policy @@ -217,8 +245,9 @@ def postprocess(self): return self.model @staticmethod - def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int, - num_stages: int) -> Tuple[List[int], int]: + def distribute_t5_layers( + num_encoder_layers: int, num_decoder_layers: int, num_stages: int + ) -> Tuple[List[int], int]: """ Distribute t5 layers into stages when pipeline parallel is used. Return the layer distribution as a list and the starting stage of decoder. @@ -251,8 +280,9 @@ def objective(num_encoder_stages): return encoder_distribution + decoder_distribution, num_encoder_stages @staticmethod - def get_t5_stage_index(layers_per_stage: List[int], stage: int, - decoder_starting_stage: int) -> Tuple[bool, int, int]: + def get_t5_stage_index( + layers_per_stage: List[int], stage: int, decoder_starting_stage: int + ) -> Tuple[bool, int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder @@ -269,16 +299,18 @@ def get_held_layers(self) -> List[nn.Module]: model = self.model encoder = self.model.encoder - decoder = getattr(self.model, 'decoder', None) + decoder = getattr(self.model, "decoder", None) num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 held_layers = [] layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages) - start_idx, end_idx = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, - decoder_starting_stage) + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + start_idx, end_idx = T5BasePolicy.get_t5_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) if stage_manager.stage < decoder_starting_stage: # current stage is in t5's encoder @@ -303,47 +335,51 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager encoder = self.model.encoder - decoder = getattr(self.model, 'decoder', None) + decoder = getattr(self.model, "decoder", None) num_encoder_layers = len(encoder.block) num_decoder_layers = len(decoder.block) if decoder else 0 layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages) + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) class T5ModelPolicy(T5BasePolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): from transformers import T5Model + policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - policy=policy, - target_key=T5Model) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=T5Model, + ) if self.pipeline_stage_manager is not None: self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy) @@ -356,9 +392,9 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block), - len(module.decoder.block), - stage_manager.num_stages) + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages + ) if id(module.decoder.embed_tokens.weight) == id(module.shared.weight): return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}] @@ -366,7 +402,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class T5ForConditionalGenerationPolicy(T5BasePolicy): - def __init__(self) -> None: super().__init__() @@ -376,22 +411,26 @@ def module_policy(self): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription(suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)) - ], - policy=policy, - target_key=T5ForConditionalGeneration) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ), + ], + policy=policy, + target_key=T5ForConditionalGeneration, + ) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=T5ForConditionalGeneration, - new_forward=T5PipelineForwards.t5_for_conditional_generation_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=T5ForConditionalGeneration, + new_forward=T5PipelineForwards.t5_for_conditional_generation_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[nn.Module]: @@ -404,9 +443,9 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block), - len(module.decoder.block), - stage_manager.num_stages) + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages + ) shared_params = [] shared_embedding = {} @@ -427,7 +466,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class T5EncoderPolicy(T5BasePolicy): - def __init__(self) -> None: super().__init__() @@ -437,17 +475,19 @@ def module_policy(self): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - policy=policy, - target_key=T5EncoderModel) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=T5EncoderModel, + ) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=T5EncoderModel, - new_forward=T5PipelineForwards.t5_encoder_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=T5EncoderModel, new_forward=T5PipelineForwards.t5_encoder_model_forward, policy=policy + ) return policy def get_held_layers(self) -> List[nn.Module]: diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index b4fb8692e684..270cdce9b091 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -16,11 +16,10 @@ ) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy'] +__all__ = ["ViTPolicy", "ViTModelPolicy", "ViTForImageClassificationPolicy", "ViTForMaskedImageModelingPolicy"] class ViTPolicy(Policy): - def config_sanity_check(self): pass @@ -28,8 +27,7 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - - from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel, ViTOutput, ViTSelfAttention + from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTOutput, ViTSelfAttention policy = {} @@ -38,77 +36,85 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") if self.shard_config.enable_tensor_parallelism: - policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForReplicatedInput, - ) - ]) - - policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={ - "attention.attention.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "attention.attention.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - ]) + policy[ViTEmbeddings] = ModulePolicyDescription( + attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForReplicatedInput, + ) + ], + ) + + policy[ViTLayer] = ModulePolicyDescription( + attribute_replacement={ + "attention.attention.num_attention_heads": self.model.config.num_attention_heads + // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + ) # use flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_vit_flash_self_attention_forward(), - }, - policy=policy, - target_key=ViTSelfAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_vit_flash_self_attention_forward(), + }, + policy=policy, + target_key=ViTSelfAttention, + ) # use jit fused operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_vit_output_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=ViTOutput) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_vit_output_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=ViTOutput, + ) return policy def new_model_class(self): @@ -121,7 +127,7 @@ def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" - if self.model.__class__.__name__ == 'ViTModel': + if self.model.__class__.__name__ == "ViTModel": module = self.model else: module = self.model.vit @@ -138,22 +144,21 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict): if self.pipeline_stage_manager: stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'ViTModel': + if self.model.__class__.__name__ == "ViTModel": module = self.model else: module = self.model.vit layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=model_cls) + method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) # ViTModel class ViTModelPolicy(ViTPolicy): - def __init__(self) -> None: super().__init__() @@ -181,26 +186,29 @@ def get_held_layers(self) -> List[nn.Module]: # ViTForImageClassification class ViTForImageClassificationPolicy(ViTPolicy): - def module_policy(self): from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: new_item = { - ViTForImageClassification: - ModulePolicyDescription(sub_module_replacement=[ + ViTForImageClassification: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) } policy.update(new_item) if self.shard_config.pipeline_stage_manager is not None: self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) - self.set_pipeline_forward(model_cls=ViTForImageClassification, - pipeline_forward=ViTForImageClassification_pipeline_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=ViTForImageClassification, + pipeline_forward=ViTForImageClassification_pipeline_forward, + policy=policy, + ) return policy @@ -219,7 +227,6 @@ def get_held_layers(self) -> List[nn.Module]: # ViTForMaskedImageModeling class ViTForMaskedImageModelingPolicy(ViTPolicy): - def __init__(self) -> None: super().__init__() @@ -230,9 +237,11 @@ def module_policy(self): if self.shard_config.pipeline_stage_manager is not None: self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) - self.set_pipeline_forward(model_cls=ViTForMaskedImageModeling, - pipeline_forward=ViTForMaskedImageModeling_pipeline_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=ViTForMaskedImageModeling, + pipeline_forward=ViTForMaskedImageModeling_pipeline_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[nn.Module]: diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 31ba82166b31..d9af2461cdb8 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -8,7 +8,6 @@ import colossalai.shardformer.layer as col_nn -from .._utils import getattr_, setattr_ from ..modeling.jit import get_jit_fused_dropout_add_func from ..modeling.whisper import ( WhisperPipelineForwards, @@ -19,13 +18,14 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ - 'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', - 'WhisperForAudioClassificationPolicy' + "WhisperPolicy", + "WhisperModelPolicy", + "WhisperForConditionalGenerationPolicy", + "WhisperForAudioClassificationPolicy", ] class WhisperPolicy(Policy): - def config_sanity_check(self): pass @@ -55,179 +55,197 @@ def module_policy(self): if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False warnings.warn( - "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + ) - #TODO using the jit fused add_and_dropout affect the accuracy + # TODO using the jit fused add_and_dropout affect the accuracy if self.shard_config.enable_jit_fused: self.shard_config.enable_jit_fused = False warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.") if self.shard_config.enable_tensor_parallelism: - policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={ - "self_attn.embed_dim": - self.model.config.d_model // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": - self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.out_proj", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="fc1", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=col_nn.Linear1D_Row, - ), - ]) - - policy[WhisperDecoderLayer] = ModulePolicyDescription(attribute_replacement={ - "self_attn.embed_dim": - self.model.config.d_model // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": - self.model.config.decoder_attention_heads // self.shard_config.tensor_parallel_size, - "encoder_attn.embed_dim": - self.model.config.d_model // self.shard_config.tensor_parallel_size, - "encoder_attn.num_heads": - self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.out_proj", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="encoder_attn.q_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="encoder_attn.k_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="encoder_attn.v_proj", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="encoder_attn.out_proj", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="fc1", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=col_nn.Linear1D_Row, - ), - ]) - - policy[WhisperDecoder] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=col_nn.VocabParallelEmbedding1D, - ), - ]) + policy[WhisperEncoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.encoder_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + + policy[WhisperDecoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.decoder_attention_heads + // self.shard_config.tensor_parallel_size, + "encoder_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size, + "encoder_attn.num_heads": self.model.config.encoder_attention_heads + // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.q_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.k_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.v_proj", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="encoder_attn.out_proj", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="fc1", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=col_nn.Linear1D_Row, + ), + ], + ) + + policy[WhisperDecoder] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ] + ) # optimization configuration if self.shard_config.enable_fused_normalization: # Handle encoder layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=WhisperEncoderLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=WhisperEncoderLayer, + ) # Handle decoder layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", - target_module=col_nn.FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=WhisperDecoderLayer) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=col_nn.FusedLayerNorm, + ), + ], + policy=policy, + target_key=WhisperDecoderLayer, + ) # handle encoder layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=WhisperEncoder) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperEncoder, + ) # handle decoder layer - self.append_or_create_submodule_replacement(description=[ - SubModuleReplacementDescription( - suffix="layer_norm", - target_module=col_nn.FusedLayerNorm, - ) - ], - policy=policy, - target_key=WhisperDecoder) + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="layer_norm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=WhisperDecoder, + ) # enable flash attention if self.shard_config.enable_flash_attention: - self.append_or_create_method_replacement(description={ - 'forward': get_whisper_flash_attention_forward(), - }, - policy=policy, - target_key=WhisperAttention) + self.append_or_create_method_replacement( + description={ + "forward": get_whisper_flash_attention_forward(), + }, + policy=policy, + target_key=WhisperAttention, + ) # use jit fused operator if self.shard_config.enable_jit_fused: - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_whisper_decoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=WhisperDecoderLayer) - self.append_or_create_method_replacement(description={ - 'forward': get_jit_fused_whisper_encoder_layer_forward(), - 'dropout_add': get_jit_fused_dropout_add_func(), - }, - policy=policy, - target_key=WhisperEncoderLayer) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_whisper_decoder_layer_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperDecoderLayer, + ) + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_whisper_encoder_layer_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=WhisperEncoderLayer, + ) return policy @@ -236,10 +254,13 @@ def add_lm_head_policy(self, base_policy): # optimize for tensor parallelism if self.shard_config.enable_tensor_parallelism: - self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( - suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - policy=base_policy, - target_key=WhisperForConditionalGeneration) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ), + policy=base_policy, + target_key=WhisperForConditionalGeneration, + ) return base_policy @@ -247,8 +268,9 @@ def postprocess(self): return self.model @staticmethod - def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int, - num_stages: int) -> Tuple[List[int], int]: + def distribute_whisper_layers( + num_encoder_layers: int, num_decoder_layers: int, num_stages: int + ) -> Tuple[List[int], int]: """ Distribute whisper layers into stages when pipeline parallel is used. Return the layer distribution as a list and the starting stage of decoder. @@ -281,8 +303,9 @@ def objective(num_encoder_stages): return encoder_distribution + decoder_distribution, num_encoder_stages @staticmethod - def get_whisper_stage_index(layers_per_stage: List[int], stage: int, - decoder_starting_stage: int) -> Tuple[bool, int, int]: + def get_whisper_stage_index( + layers_per_stage: List[int], stage: int, decoder_starting_stage: int + ) -> Tuple[bool, int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder @@ -293,13 +316,12 @@ def get_whisper_stage_index(layers_per_stage: List[int], stage: int, return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) def get_held_layers(self) -> List[nn.Module]: - assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'WhisperModel': + if self.model.__class__.__name__ == "WhisperModel": model = self.model - elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + elif self.model.__class__.__name__ == "WhisperForConditionalGeneration": model = self.model.model else: model = None @@ -320,9 +342,11 @@ def get_held_layers(self) -> List[nn.Module]: held_layers = [] layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages) - start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, - decoder_starting_stage) + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + start_idx, end_idx = WhisperPolicy.get_whisper_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) if stage_manager.stage < decoder_starting_stage: # current stage is in whisper's encoder @@ -347,14 +371,14 @@ def get_held_layers(self) -> List[nn.Module]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'WhisperModel': + if self.model.__class__.__name__ == "WhisperModel": model = self.model - elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration': + elif self.model.__class__.__name__ == "WhisperForConditionalGeneration": model = self.model.model else: model = None @@ -373,34 +397,37 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli num_decoder_layers = 0 layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( - num_encoder_layers, num_decoder_layers, stage_manager.num_stages) - stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage, - decoder_starting_stage) + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) + stage_index = WhisperPolicy.get_whisper_stage_index( + layers_per_stage, stage_manager.stage, decoder_starting_stage + ) method_replacement = { - 'forward': - partial(new_forward, - stage_manager=stage_manager, - stage_index=stage_index, - decoder_starting_stage=decoder_starting_stage) + "forward": partial( + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + decoder_starting_stage=decoder_starting_stage, + ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) # WhisperModel class WhisperModelPolicy(WhisperPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): from transformers import WhisperModel + policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=WhisperModel, - new_forward=WhisperPipelineForwards.whisper_model_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=WhisperModel, new_forward=WhisperPipelineForwards.whisper_model_forward, policy=policy + ) return policy @@ -414,19 +441,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # WhisperForConditionalGeneration class WhisperForConditionalGenerationPolicy(WhisperPolicy): - def __init__(self) -> None: super().__init__() def module_policy(self): from transformers import WhisperForConditionalGeneration + policy = super().module_policy() policy = self.add_lm_head_policy(policy) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration, - new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=WhisperForConditionalGeneration, + new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward, + policy=policy, + ) return policy def postprocess(self): @@ -457,8 +486,9 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: stage_manager = self.pipeline_stage_manager if stage_manager is not None and stage_manager.num_stages > 1: - _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers, - stage_manager.num_stages) + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + num_encoder_layers, num_decoder_layers, stage_manager.num_stages + ) shared_params = [] shared_embedding = {} if id(module.proj_out) == id(model.decoder.embed_tokens): @@ -472,7 +502,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: # WhisperForAudioClassification class WhisperForAudioClassificationPolicy(WhisperPolicy): - def __init__(self) -> None: super().__init__() @@ -481,12 +510,15 @@ def preprocess(self): def module_policy(self): from transformers import WhisperForAudioClassification + policy = super().module_policy() if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=WhisperForAudioClassification, - new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=WhisperForAudioClassification, + new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward, + policy=policy, + ) return policy def get_held_layers(self) -> List[nn.Module]: diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index 7abdd45ec7c5..acf8a95a41ca 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -2,4 +2,4 @@ from .sharder import ModelSharder from .shardformer import ShardFormer -__all__ = ['ShardConfig', 'ModelSharder', 'ShardFormer'] +__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0b6e1640952b..6935288130c9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -6,7 +6,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager -__all__ = ['ShardConfig'] +__all__ = ["ShardConfig"] @dataclass @@ -45,7 +45,8 @@ def tensor_parallel_size(self): def __post_init__(self): if not self.enable_tensor_parallelism and self.enable_sequence_parallelism: raise ValueError( - "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True") + "enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True" + ) if not self.enable_sequence_parallelism and self.enable_sequence_overlap: raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True") if not self.enable_tensor_parallelism: diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 7592069a2dd9..1bed850c6581 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -12,7 +12,7 @@ from .shard_config import ShardConfig from .utils import set_tensors_to_none -__all__ = ['ModelSharder', 'shard_model'] +__all__ = ["ModelSharder", "shard_model"] class ModelSharder(object): @@ -64,13 +64,15 @@ def _replace_module(self, include: Optional[Set[nn.Module]] = None) -> None: param_replacement = module_description.param_replacement sub_module_replacement = module_description.sub_module_replacement method_replacement = module_description.method_replacement - self._recursive_replace_layer(self.model, - layer_cls, - attr_replacement, - param_replacement, - method_replacement, - sub_module_replacement, - include=include) + self._recursive_replace_layer( + self.model, + layer_cls, + attr_replacement, + param_replacement, + method_replacement, + sub_module_replacement, + include=include, + ) def _recursive_replace_layer( self, @@ -94,8 +96,9 @@ def _recursive_replace_layer( sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None """ - if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ - (module.__class__ == origin_cls): + if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or ( + module.__class__ == origin_cls + ): if attr_replacement is not None: self._replace_attr(module, attr_replacement) @@ -109,13 +112,15 @@ def _recursive_replace_layer( self._replace_sub_module(module, sub_module_replacement, include) for name, child in module.named_children(): - self._recursive_replace_layer(child, - origin_cls, - attr_replacement, - param_replacement, - method_replacement, - sub_module_replacement, - include=include) + self._recursive_replace_layer( + child, + origin_cls, + attr_replacement, + param_replacement, + method_replacement, + sub_module_replacement, + include=include, + ) def _replace_attr( self, @@ -153,10 +158,12 @@ def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Calla bound_method = MethodType(new_method, module) setattr(module, method_name, bound_method) - def _replace_sub_module(self, - org_layer: nn.Module, - sub_module_replacement: List[SubModuleReplacementDescription], - include: Optional[Set[nn.Module]] = None) -> None: + def _replace_sub_module( + self, + org_layer: nn.Module, + sub_module_replacement: List[SubModuleReplacementDescription], + include: Optional[Set[nn.Module]] = None, + ) -> None: r""" Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict @@ -170,7 +177,7 @@ def _replace_sub_module(self, target_module = description.target_module kwargs = {} if description.kwargs is None else description.kwargs - assert target_module is not None, 'target_module should not be None' + assert target_module is not None, "target_module should not be None" native_sub_module = getattr_(org_layer, suffix, ignore=True) @@ -178,8 +185,9 @@ def _replace_sub_module(self, if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include): continue - assert not isinstance(native_sub_module, target_module), \ - f"The module with suffix {suffix} has been replaced, please check the policy" + assert not isinstance( + native_sub_module, target_module + ), f"The module with suffix {suffix} has been replaced, please check the policy" # if it is None and we are allowed to ignore this module # just skip @@ -187,9 +195,9 @@ def _replace_sub_module(self, continue try: - replace_layer = target_module.from_native_module(native_sub_module, - self.shard_config.tensor_parallel_process_group, - **kwargs) + replace_layer = target_module.from_native_module( + native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs + ) except Exception as e: raise RuntimeError( f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}" @@ -200,7 +208,6 @@ def _replace_sub_module(self, setattr_(org_layer, suffix, replace_layer) def _get_recursive_held_layers(self, held_layers: Optional[List[nn.Module]]) -> Optional[List[nn.Module]]: - def collect_sub_modules(module: nn.Module): if module is None: return diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index 099376d931e8..9ed149f33f2f 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -5,7 +5,14 @@ from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor __all__ = [ - 'ColoTensor', 'convert_parameter', 'named_params_with_colotensor', 'ColoParameter', 'ColoParamOpHook', - 'ColoParamOpHookManager', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', - 'merge_same_dim_mesh_list' + "ColoTensor", + "convert_parameter", + "named_params_with_colotensor", + "ColoParameter", + "ColoParamOpHook", + "ColoParamOpHookManager", + "CommSpec", + "CollectiveCommPattern", + "convert_dim_partition_dict", + "merge_same_dim_mesh_list", ] diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index 076661a08824..5712505ae2ff 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -11,7 +11,7 @@ def is_no_hook_op(func) -> bool: - return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS + return func.__name__.startswith("__") and func not in WHITE_LIST_FUNCS def filter_colo_parameters(*args, **kwargs): @@ -36,18 +36,16 @@ def get_colo_parameters(element) -> None: def replace_args(args, kwargs, new_args): - args = new_args[:len(args)] - for k, v in zip(kwargs.keys(), new_args[len(args):]): + args = new_args[: len(args)] + for k, v in zip(kwargs.keys(), new_args[len(args) :]): kwargs[k] = v return tuple(args), kwargs class ColoParameter(ColoTensor, torch.nn.Parameter): - r"""A kind of ColoTensor to be considered as a module parameter. + r"""A kind of ColoTensor to be considered as a module parameter.""" - """ - - def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> 'ColoParameter': + def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> "ColoParameter": if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index a20a1444a406..c2de9abce371 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -7,7 +7,7 @@ torch.Tensor.add_: torch.Tensor.add, torch.Tensor.sub_: torch.Tensor.sub, torch.Tensor.mul_: torch.Tensor.mul, - torch.Tensor.div_: torch.Tensor.div + torch.Tensor.div_: torch.Tensor.div, } @@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]: Tensor._base.__get__, Tensor.grad.__get__, Tensor._grad.__get__, - Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor + Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor } @@ -37,17 +37,18 @@ def _convert_output(output, func): class ColoTensor(torch.Tensor): - """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor. + """Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor. It is only used to trigger the torch function hook. Args: data (torch.Tensor): a torch tensor used as the payload the colotensor. """ - torch_major = int(torch.__version__.split('.')[0]) - torch_minor = int(torch.__version__.split('.')[1]) - def __new__(cls, data: torch.Tensor) -> 'ColoTensor': + torch_major = int(torch.__version__.split(".")[0]) + torch_minor = int(torch.__version__.split(".")[1]) + + def __new__(cls, data: torch.Tensor) -> "ColoTensor": """ The signature of the __new__ has to be consistent with the torch.Tensor. @@ -74,7 +75,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): # we have to capture the `backward` function # and make sure that it does not in `torch._C.DisableTorchFunction()` context if func is torch.Tensor.backward: - assert len(args) == 1 # only has 1 parameter + assert len(args) == 1 # only has 1 parameter backward_tensor = torch.Tensor(args[0]) tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()} return backward_tensor.backward(**tensor_kwargs) @@ -83,8 +84,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if func in INPALCE_MAPPING: func = INPALCE_MAPPING[func] # set the 'inplace' kwargs to False - if 'inplace' in kwargs: - kwargs['inplace'] = False + if "inplace" in kwargs: + kwargs["inplace"] = False with torch._C.DisableTorchFunction(): ret = func(*args, **kwargs) diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index 204f81343199..de0cba26b52a 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -7,15 +7,15 @@ from torch.distributed import ReduceOp __all__ = [ - 'CollectiveCommPattern', - 'CommSpec', + "CollectiveCommPattern", + "CommSpec", ] def _all_gather(tensor, comm_spec): - ''' + """ Implement all gather operation on device mesh based on information provided by comm_spec. - ''' + """ process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() process_group = process_groups[comm_spec.logical_process_axis] @@ -31,9 +31,9 @@ def _all_gather(tensor, comm_spec): def _split(tensor, comm_spec): - ''' + """ Implement shard operation on device mesh based on information provided by comm_spec. - ''' + """ process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() process_group = process_groups[comm_spec.logical_process_axis] @@ -45,9 +45,9 @@ def _split(tensor, comm_spec): def _all_to_all(tensor, comm_spec): - ''' + """ Implement all to all operation on device mesh based on information provided by comm_spec. - ''' + """ process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() process_group = process_groups[comm_spec.logical_process_axis] world_size = dist.get_world_size(process_group) @@ -66,9 +66,9 @@ def _all_to_all(tensor, comm_spec): def _all_reduce(tensor, comm_spec, async_op=False): - ''' + """ Implement all reduce operation on device mesh based on information provided by comm_spec. - ''' + """ process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() process_group = process_groups[comm_spec.logical_process_axis] @@ -79,7 +79,7 @@ def _all_reduce(tensor, comm_spec, async_op=False): def _mix_gather(tensor, comm_spec): - ''' + """ Implement mix gather operation on device mesh based on information provided by comm_spec. Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather @@ -124,7 +124,7 @@ def _mix_gather(tensor, comm_spec): leading_group_dim = 1 process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] - ''' + """ total_slices = comm_spec.device_mesh.shape[0] tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)] leading_group_dim = comm_spec.logical_process_axes[0] @@ -155,15 +155,16 @@ def _mix_gather(tensor, comm_spec): torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1]) ] for i in range(cat_slice[1]): - tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]), - comm_spec.gather_dim[0]).contiguous() + tmp_tensor_list[i] = torch.cat( + tuple(tensor_list[i * cat_slice[0] : (i + 1) * cat_slice[0]]), comm_spec.gather_dim[0] + ).contiguous() output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous() return output def _mix_split(tensor, comm_spec): - ''' + """ Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent) Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension. @@ -177,7 +178,7 @@ def _mix_split(tensor, comm_spec): # [[0, 1, 2, 3], # [4, 5, 6, 7]] # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} - ''' + """ mesh_shape = comm_spec.device_meshes.shape dim = comm_spec.gather_dim total_slices = comm_spec.device_mesh.shape[0] @@ -316,11 +317,13 @@ def symbolic(graph, input_): @staticmethod def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) - comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - sharding_spec=comm_spec.sharding_spec, - gather_dim=comm_spec.shard_dim, - shard_dim=comm_spec.gather_dim, - logical_process_axis=comm_spec.logical_process_axis) + comm_spec_for_backward = CommSpec( + comm_pattern=comm_spec.comm_pattern, + sharding_spec=comm_spec.sharding_spec, + gather_dim=comm_spec.shard_dim, + shard_dim=comm_spec.gather_dim, + logical_process_axis=comm_spec.logical_process_axis, + ) ctx.comm_spec = comm_spec_for_backward return output @@ -330,7 +333,6 @@ def backward(ctx, grad_outputs): class _MixGatherForwardMixSplitBackward(torch.autograd.Function): - @staticmethod def symbolic(graph, input_): return _mix_gather(input_) @@ -370,16 +372,16 @@ def mixgather_forward_split_backward(input_, comm_spec): class CollectiveCommPattern(Enum): - GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' - ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' - SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' - ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' - IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd" + ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd" + SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd" + ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd" + IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd" MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" class CommSpec: - ''' + """ Communication spec is used to record the communication action. It has two main functions: 1. Compute the communication cost which will be used in auto parallel solver. 2. Convert the communication spec to real action which will be used in runtime. @@ -393,16 +395,18 @@ class CommSpec: gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. - ''' - - def __init__(self, - comm_pattern, - sharding_spec, - gather_dim=None, - shard_dim=None, - logical_process_axis=None, - forward_only=False, - mix_gather=False): + """ + + def __init__( + self, + comm_pattern, + sharding_spec, + gather_dim=None, + shard_dim=None, + logical_process_axis=None, + forward_only=False, + mix_gather=False, + ): self.comm_pattern = comm_pattern self.sharding_spec = sharding_spec self.gather_dim = gather_dim @@ -449,14 +453,14 @@ def __repr__(self): res_list.append(f"gather_dim:{self.gather_dim}, ") res_list.append(f"logical_process_asex:{self.logical_process_axes})") - return ''.join(res_list) + return "".join(res_list) def get_comm_cost(self): - ''' + """ For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to compute the communication cost. For shard operation, it is an on-chip operation, so the communication cost is zero. - ''' + """ comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1) cost_dict = {} if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: @@ -500,13 +504,13 @@ def get_comm_cost(self): return cost_dict def covert_spec_to_action(self, tensor): - ''' + """ Convert CommSpec into runtime action, implement real collection communication to target tensor. The collection communication action is directed by the CommSpec. Argument: tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. - ''' + """ if self.comm_pattern in pattern_to_func_dict: tensor = pattern_to_func_dict[self.comm_pattern](tensor, self) else: diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index 3ae38a12555b..fad5101d380c 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -21,8 +21,23 @@ from .sharding_spec import ShardingSpec __all__ = [ - 'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise', - 'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh', - 'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization', - 'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec' + "is_distributed_tensor", + "distribute_tensor", + "to_global", + "is_sharded", + "shard_rowwise", + "shard_colwise", + "sharded_tensor_to_param", + "compute_global_numel", + "get_sharding_spec", + "get_global_shape", + "get_device_mesh", + "redistribute", + "get_layout", + "is_customized_distributed_tensor", + "distribute_tensor_with_customization", + "to_global_for_customized_distributed_tensor", + "customized_distributed_tensor_to_param", + "Layout", + "ShardingSpec", ] diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 9848e4ca423e..178bac428ea9 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -44,7 +44,7 @@ def is_sharded(dtensor: torch.Tensor) -> bool: Returns: bool: True if the tensor is sharded, False otherwise. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." return list(dtensor.shape) == list(dtensor.dist_layout.global_shape) @@ -77,8 +77,10 @@ def new_clone(self, *args, **kwargs): return dtensor -def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: - ''' +def _construct_default_sharding_spec( + tensor: torch.Tensor, +) -> ShardingSpec: + """ Construct the default sharding specification for the tensor. Args: @@ -86,14 +88,14 @@ def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: Returns: A `ShardingSpec` object without any sharding specified. - ''' + """ return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={}) def _apply_layout(tensor, layout): - ''' + """ Apply the layout to the local tensor during initializing process. - ''' + """ # layout converter requires a source and target laytout # we construct the source layer for an unsharded tensor # and use self.dist_layer as the targer layout for the sharded tensor @@ -115,7 +117,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp Returns: torch.Tensor: The distributed tensor. """ - assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.' + assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape) # shard tensor @@ -128,7 +130,7 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: - ''' + """ Convert the layout of the tensor from source_spec to target_spec. This will update the `local_tensor` and `dist_layout` in place. @@ -136,13 +138,13 @@ def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: dtensor (torch.Tensor): the distributed tensor to be converted. device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices. target_layout (Layout): the target layout specification. - ''' - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + """ + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." global_shape = get_global_shape(dtensor) target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) - resharded_tensor = layout_converter.apply(tensor=dtensor, - source_layout=dtensor.dist_layout, - target_layout=target_layout) + resharded_tensor = layout_converter.apply( + tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout + ) return resharded_tensor @@ -157,7 +159,7 @@ def to_global(dtensor: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: the global tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." layout_converter = LayoutConverter() global_sharding_spec = ShardingSpec(dtensor.dim(), {}) @@ -193,7 +195,7 @@ def shard_rowwise( if isinstance(group_or_device_mesh, ProcessGroup): device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) else: - assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' + assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding." device_mesh = group_or_device_mesh sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]}) @@ -222,7 +224,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup if isinstance(group_or_device_mesh, ProcessGroup): device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) else: - assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' + assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding." device_mesh = group_or_device_mesh sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]}) @@ -230,7 +232,7 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) # make it distributed as well @@ -241,7 +243,7 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None: - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." param.data = dtensor # make it distributed as well param.dist_layout = dtensor.dist_layout @@ -258,7 +260,7 @@ def compute_global_numel(dtensor: torch.Tensor) -> int: Returns: int: The global number of elements in the distributed tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." numel = reduce(operator.mul, dtensor.dist_layout.global_shape) return numel @@ -274,7 +276,7 @@ def get_layout(dtensor: torch.Tensor) -> Layout: Layout: The layout of the distributed tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." return dtensor.dist_layout @@ -288,7 +290,7 @@ def get_global_shape(dtensor: torch.Tensor) -> torch.Size: Returns: torch.Size: The global shape of the distributed tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." return dtensor.dist_layout.global_shape @@ -302,7 +304,7 @@ def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh: Returns: DeviceMesh: The device mesh of the distributed tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." return dtensor.dist_layout.device_mesh @@ -316,7 +318,7 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec: Returns: ShardingSpec: The sharding spec of the distributed tensor. """ - assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor." return dtensor.dist_layout.sharding_spec @@ -335,7 +337,7 @@ def is_customized_distributed_tensor(tensor: torch.Tensor): Returns: bool: Whether the given tensor is a customized distributed tensor. """ - return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn') + return hasattr(tensor, "shard_fn") and hasattr(tensor, "gather_fn") def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: @@ -402,9 +404,9 @@ def gather_fn(tensor): Returns: torch.Tensor: The distributed tensor. """ - assert callable(shard_fn), 'The shard_fn must be callable.' - assert callable(gather_fn), 'The gather_fn must be callable.' - assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.' + assert callable(shard_fn), "The shard_fn must be callable." + assert callable(gather_fn), "The gather_fn must be callable." + assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." sharded_tensor = shard_fn(tensor) @@ -428,7 +430,7 @@ def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch. Returns: torch.Tensor: The global tensor. """ - assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor." return dtensor.gather_fn(dtensor) @@ -436,7 +438,7 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: """ Convert the given customized distributed tensor to a parameter. """ - assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor." param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) @@ -451,7 +453,7 @@ def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param """ Convert the given customized distributed tensor to an existing parameter. """ - assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor." param.data = dtensor.data param.shard_fn = dtensor.shard_fn diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 6158d0bfe2ad..8f5b52aab8f8 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -6,22 +6,22 @@ from torch.distributed import ReduceOp __all__ = [ - 'CollectiveCommPattern', - 'CommSpec', + "CollectiveCommPattern", + "CommSpec", ] class CollectiveCommPattern(Enum): - GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' - ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' - SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' - ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' - IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd" + ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd" + SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd" + ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd" + IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd" MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" class CommSpec: - ''' + """ Communication spec is used to record the communication action. It converts the communication spec to real action which will be used in runtime. It contains comm_pattern to determine the communication method, process_group_dict to determine the process groups, gather_dim and shard_dim @@ -33,14 +33,16 @@ class CommSpec: gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. - ''' - - def __init__(self, - comm_pattern: CollectiveCommPattern, - process_group_dict: Dict, - gather_dim: int = None, - shard_dim: int = None, - logical_process_axis: int = None): + """ + + def __init__( + self, + comm_pattern: CollectiveCommPattern, + process_group_dict: Dict, + gather_dim: int = None, + shard_dim: int = None, + logical_process_axis: int = None, + ): self.comm_pattern = comm_pattern self.gather_dim = gather_dim self.shard_dim = shard_dim @@ -71,16 +73,16 @@ def __repr__(self): res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") res_list.append(f"logical_process_axis:{self.logical_process_axis})") - return ''.join(res_list) + return "".join(res_list) def covert_spec_to_action(self, tensor): - ''' + """ Convert CommSpec into runtime action, implement real collection communication to target tensor. The collection communication action is directed by the CommSpec. Argument: tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. - ''' + """ if self.comm_pattern in pattern_to_func_dict: tensor = pattern_to_func_dict[self.comm_pattern](tensor, self) else: @@ -89,9 +91,9 @@ def covert_spec_to_action(self, tensor): def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): - ''' + """ Implement all gather operation on device mesh based on information provided by comm_spec. - ''' + """ process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] world_size = dist.get_world_size(process_group) tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] @@ -103,9 +105,9 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): def _split(tensor: torch.Tensor, comm_spec: CommSpec): - ''' + """ Implement shard operation on device mesh based on information provided by comm_spec. - ''' + """ process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] dim = comm_spec.shard_dim length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) @@ -115,9 +117,9 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec): def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): - ''' + """ Implement all to all operation on device mesh based on information provided by comm_spec. - ''' + """ process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] world_size = dist.get_world_size(process_group) new_shape = list(tensor.shape) @@ -134,9 +136,9 @@ def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): - ''' + """ Implement all reduce operation on device mesh based on information provided by comm_spec. - ''' + """ process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] if not tensor.is_contiguous(): tensor = tensor.contiguous() @@ -256,11 +258,13 @@ def symbolic(graph, input_): @staticmethod def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) - comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - process_group_dict=comm_spec.process_group_dict, - gather_dim=comm_spec.shard_dim, - shard_dim=comm_spec.gather_dim, - logical_process_axis=comm_spec.logical_process_axis) + comm_spec_for_backward = CommSpec( + comm_pattern=comm_spec.comm_pattern, + process_group_dict=comm_spec.process_group_dict, + gather_dim=comm_spec.shard_dim, + shard_dim=comm_spec.gather_dim, + logical_process_axis=comm_spec.logical_process_axis, + ) ctx.comm_spec = comm_spec_for_backward return output diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index a35b2f43e44b..6d4c5dbe3c09 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -25,15 +25,16 @@ def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_ self._sanity_check() def __hash__(self) -> int: - return hash(f'{self.sharding_spec}') + return hash(f"{self.sharding_spec}") def get_sharded_shape_per_device(self): sharded_shape = list(self.global_shape) for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) - assert sharded_shape[ - dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + assert ( + sharded_shape[dim] % shard_partitions == 0 + ), f"Cannot shard dimension {dim} into {shard_partitions} partitions." sharded_shape[dim] //= shard_partitions return torch.Size(sharded_shape) @@ -49,7 +50,8 @@ def _sanity_check(self): dim_check_list.remove(element) else: raise DuplicatedShardingDimensionError( - f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}." + ) # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): @@ -61,5 +63,5 @@ def _sanity_check(self): if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( - f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices." ) diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index 528ed7901c4f..e031e0472b0b 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -14,7 +14,7 @@ from .sharding_spec import ShardingSpec from .utils import get_comm_cost -__all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_options'] +__all__ = ["LayoutConverter", "LayoutConverterOptions", "set_layout_converting_options"] @dataclass @@ -22,8 +22,8 @@ class LayoutConverterOptions: """ LayoutConverterOptions is a dataclass which specifies the preferences for layout converting. """ + # TODO: layout converter option is not implemented yet - pass def set_layout_converting_options(options: LayoutConverterOptions): @@ -63,7 +63,7 @@ def forward_only(self, value): self._forward_only = value def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with single all-gather operation. For the all-gather operation, we just care about the S dimension. @@ -96,7 +96,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co Output: [R, S1, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0) [S0, R, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1) - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD source_spec = source_layout.sharding_spec @@ -125,16 +125,19 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co comm_pattern, process_group_dict=process_group_dict, gather_dim=gather_dim, - # shard_dim will be used during backward + # shard_dim will be used during backward shard_dim=gather_dim, - logical_process_axis=logical_process_axis) + logical_process_axis=logical_process_axis, + ) # generate new sharding spec try: new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) - new_layout = Layout(device_mesh=source_layout.device_mesh, - sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + new_layout = Layout( + device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + global_shape=source_layout.global_shape, + ) valid_spec_dict[new_layout] = comm_spec except LayoutException: @@ -142,7 +145,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co return valid_spec_dict def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with single all-to-all operation. For the all-to-all operation, we just care about the pairs containing S dimension. @@ -176,7 +179,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com [S01, R, R]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:0, logical_process_axis: 1) [R, S1, S0]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:0, shard_dim:2, logical_process_axis: 0) [S0, R, S1]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:2, logical_process_axis: 1) - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD @@ -224,11 +227,13 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com gather_dim = b_index shard_dim = f_index logical_process_axis = b_target_pair[1][-1] - comm_spec = CommSpec(comm_pattern, - process_group_dict=process_group_dict, - gather_dim=gather_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis) + comm_spec = CommSpec( + comm_pattern, + process_group_dict=process_group_dict, + gather_dim=gather_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + ) new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) @@ -246,9 +251,11 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com # generate new sharding spec try: new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) - new_layout = Layout(device_mesh=source_layout.device_mesh, - sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + new_layout = Layout( + device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + global_shape=source_layout.global_shape, + ) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -256,7 +263,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com return valid_spec_dict def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with single shard operation. For the sharding operation, we just care about legal sharding dimensions. @@ -291,7 +298,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec [S01, R, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1) [S0, S1, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1) [S0, R, S1]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:2, shard_dim:2, logical_process_axis:1) - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD source_spec = source_layout.sharding_spec @@ -326,26 +333,31 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec shard_dim = index logical_process_axis = shard_list[-1] - comm_spec = CommSpec(comm_pattern, - process_group_dict=process_group_dict, - gather_dim=shard_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis) + comm_spec = CommSpec( + comm_pattern, + process_group_dict=process_group_dict, + gather_dim=shard_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + ) # generate new sharding spec try: - new_sharding_spec = ShardingSpec(dim_size=source_spec.dims, - dim_partition_dict=new_dim_partition_dict) - new_layout = Layout(device_mesh=source_layout.device_mesh, - sharding_spec=new_sharding_spec, - global_shape=source_layout.global_shape) + new_sharding_spec = ShardingSpec( + dim_size=source_spec.dims, dim_partition_dict=new_dim_partition_dict + ) + new_layout = Layout( + device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + global_shape=source_layout.global_shape, + ) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass return valid_spec_dict def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, CommSpec]: - ''' + """ Get all valid layouts from source_layout with one step transform. Note: @@ -358,16 +370,17 @@ def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, Return: valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with one step transform. - ''' + """ valid_spec_dict = {} valid_spec_dict.update(self.all_gather_transform_layouts(source_layout)) valid_spec_dict.update(self.all_to_all_transform_layout(source_layout)) valid_spec_dict.update(self.shard_transform_layout(source_layout)) return valid_spec_dict - def layout_converting(self, source_layout: Layout, - target_layout: Layout) -> Tuple[List[Layout], List[CommSpec], float]: - ''' + def layout_converting( + self, source_layout: Layout, target_layout: Layout + ) -> Tuple[List[Layout], List[CommSpec], float]: + """ This method will find a path to transform source_layout to target_layout with a greedy algorithm. The basic idea is: @@ -419,7 +432,7 @@ def layout_converting(self, source_layout: Layout, output: [R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R] - ''' + """ source_spec = source_layout.sharding_spec target_spec = target_layout.sharding_spec MAX_TRANSFORM_STEPS = 20 @@ -470,11 +483,11 @@ def layout_converting(self, source_layout: Layout, raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> Dict[str, float]: - ''' + """ Get the total communication cost of the layout converting process. - ''' + """ transform_path, comm_action_sequence = self.layout_converting(source_layout, target_layout) - total_cost = {'forward': 0.0, 'backward': 0.0, 'total': 0.0} + total_cost = {"forward": 0.0, "backward": 0.0, "total": 0.0} for layout, comm_spec in zip(transform_path, comm_action_sequence): cost_dict = get_comm_cost(layout, comm_spec, self.forward_only) for key in total_cost: @@ -482,7 +495,7 @@ def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> D return total_cost def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout) -> torch.Tensor: - ''' + """ Apply target_layout to tensor with source layout, the transform path is generated by the layout_converting method. @@ -542,7 +555,7 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo [1.], [3.], [3.]]) - ''' + """ _, comm_action_sequence = self.layout_converting(source_layout, target_layout) for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index 565012b58a03..2ac0ca73e4b8 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -4,16 +4,16 @@ from ..utils import merge_same_dim_mesh_list from .misc import ShardingOutOfIndexError -__all__ = ['DimSpec', 'ShardingException', 'ShardingSpec'] +__all__ = ["DimSpec", "ShardingException", "ShardingSpec"] ALLGATHER_COST = 20 SHARD_COST = 5 STEP_PENALTY = 6 -NAN = 'nan' +NAN = "nan" class DimSpec: - ''' + """ Sharding spec for single dimension of the sharded tensor describe the sharding dimension of logical device mesh and give a method to compute the difference between them. This class is used internally in ShardingSpec. @@ -21,7 +21,7 @@ class DimSpec: Argument: shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. Otherwise, the element in shard_list means the data will be sharded in that dimension. - ''' + """ def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 @@ -33,41 +33,40 @@ def __eq__(self, other): def __repr__(self): if self.is_replica: - return 'R' - target = 'S' + return "R" + target = "S" for dim in self.shard_list: target += str(dim) return target def _convert_str_to_shard_list(self, str_spec): - ''' + """ Convert str_spec into shard_list. Argument: str_spec(str): dim spec in str type. - ''' + """ - if str_spec == 'R': + if str_spec == "R": return [] - if str_spec == 'S0': + if str_spec == "S0": return [0] - if str_spec == 'S1': + if str_spec == "S1": return [1] - if str_spec == 'S01': + if str_spec == "S01": return [0, 1] def build_difference_2d_dict(self): - ''' + """ Build a difference mapping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. - ''' + """ - source_spec_list = ['R', 'S0', 'S1', 'S01'] - target_spec_list = ['R', 'S0', 'S1', 'S01'] + source_spec_list = ["R", "S0", "S1", "S01"] + target_spec_list = ["R", "S0", "S1", "S01"] difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - legal_sharding_dims = [] spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) source_shard_list = self._convert_str_to_shard_list(source_spec) target_shard_list = self._convert_str_to_shard_list(target_spec) @@ -77,14 +76,17 @@ def build_difference_2d_dict(self): difference = 0 # all_gather(source) -> target - elif len(source_shard_list - ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list + ): difference = ALLGATHER_COST # shard(source) -> target - elif len(source_shard_list) == len( - target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ - -1] not in source_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) - 1 + and source_shard_list == target_shard_list[:-1] + and target_shard_list[-1] not in source_shard_list + ): difference = SHARD_COST # S1 -> S0 or S0 -> S1 @@ -115,7 +117,7 @@ def build_difference_2d_dict(self): self.difference_dict = difference_dict def dim_diff(self, other): - ''' + """ The difference between two _DimSpec. Argument: @@ -131,13 +133,13 @@ def dim_diff(self, other): Output: 5 - ''' + """ difference = self.difference_dict[(str(self), str(other))] return difference class ShardingSpec: - ''' + """ Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like [R, R, S0, S1], which means @@ -145,23 +147,27 @@ class ShardingSpec: dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, and the value of the key describe which logical axis will be sharded in that dimension. sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. - ''' + """ - def __init__(self, - dim_size: int, - dim_partition_dict: Dict[int, List[int]] = None, - sharding_sequence: List[DimSpec] = None): + def __init__( + self, dim_size: int, dim_partition_dict: Dict[int, List[int]] = None, sharding_sequence: List[DimSpec] = None + ): self.dims = dim_size self.dim_partition_dict = dim_partition_dict self.sharding_sequence = sharding_sequence if self.sharding_sequence is None: - assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' - self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=self.dims, - dim_partition_dict=self.dim_partition_dict) + assert ( + self.dim_partition_dict is not None + ), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object." + self.dim_partition_dict = merge_same_dim_mesh_list( + dim_size=self.dims, dim_partition_dict=self.dim_partition_dict + ) self.sharding_sequence = self.convert_dict_to_shard_sequence() elif self.dim_partition_dict is None: - assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' + assert ( + self.sharding_sequence is not None + ), f"sharding_sequence should not be None, if dim_partition_dict is NoneType object." self.dim_partition_dict = self.convert_shard_sequence_to_dict() self._sanity_check() @@ -169,31 +175,32 @@ def __init__(self, def _sanity_check(self): if len(self.sharding_sequence) > self.dims: raise ShardingOutOfIndexError( - f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.') + f"sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}." + ) if list(self.dim_partition_dict.keys()) and max(list(self.dim_partition_dict.keys())) >= self.dims: raise ShardingOutOfIndexError( - f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.' + f"the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}." ) def __repr__(self): res_list = ["ShardingSpec:"] res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) - return ' '.join(res_list) + return " ".join(res_list) def convert_dict_to_shard_sequence(self): - ''' + """ Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence. - ''' + """ sharding_sequence = [DimSpec([])] * self.dims for dim, shard_list in self.dim_partition_dict.items(): sharding_sequence[dim] = DimSpec(shard_list) return sharding_sequence def convert_shard_sequence_to_dict(self): - ''' + """ Convert sharding_sequence into dim_partition_dict. - ''' + """ new_dim_partition_dict = {} for index, dim_spec in enumerate(self.sharding_sequence): if not dim_spec.is_replica: @@ -203,7 +210,7 @@ def convert_shard_sequence_to_dict(self): return new_dim_partition_dict def spec_diff(self, other): - ''' + """ This function is a naive version of difference computation. It just simply accumulates difference every dimension between the pair of sharding sequence. @@ -228,9 +235,10 @@ def spec_diff(self, other): Return: difference(int): Difference between two ShardingSpec. - ''' + """ assert len(self.sharding_sequence) == len( - other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + other.sharding_sequence + ), f"Cannot compare difference for two sharding specs with different length." difference = 0 for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): difference += orig_dim_spec.dim_diff(other_dim_spec) diff --git a/colossalai/tensor/d_tensor/utils.py b/colossalai/tensor/d_tensor/utils.py index fc22b990d879..8f0081246fb3 100644 --- a/colossalai/tensor/d_tensor/utils.py +++ b/colossalai/tensor/d_tensor/utils.py @@ -7,7 +7,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False) -> Dict[str, float]: - ''' + """ This method is used to compute the communication cost for a given layout and comm_spec. For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to @@ -18,7 +18,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals comm_spec: the comm_spec to instruct the communication operation. forward_only: if it is True, we will just count the forward communication cost. If it is False, we will count both forward and backward communication cost. - ''' + """ comm_size = reduce(operator.mul, layout.get_sharded_shape_per_device(), 1) device_mesh = layout.device_mesh comm_pattern = comm_spec.comm_pattern diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py index e37859bac0c3..1fe99cd89a4e 100644 --- a/colossalai/tensor/param_op_hook.py +++ b/colossalai/tensor/param_op_hook.py @@ -36,6 +36,7 @@ class ColoParamOpHookManager: Manage your param op hooks. It only has static methods. The only static method you should call is ``use_hooks(*hooks)``. """ + hooks: Tuple[ColoParamOpHook, ...] = tuple() @staticmethod @@ -99,7 +100,6 @@ def has_hook() -> bool: class PreFwdPostBwd(torch.autograd.Function): - @staticmethod def forward(ctx, params, *args): ctx.params = params @@ -112,7 +112,6 @@ def backward(ctx, *grads): class PostFwdPreBwd(torch.autograd.Function): - @staticmethod def forward(ctx, params, args): ctx.params = params diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index b837333a2388..409561b3a26b 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -13,7 +13,7 @@ from .comm_spec import * -__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options'] +__all__ = ["ShapeConsistencyManager", "ShapeConsistencyOptions", "set_shape_consistency_options"] @dataclass @@ -21,16 +21,17 @@ class ShapeConsistencyOptions: """ ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency. """ + # TODO: shape consistency option is not implemented yet - pass def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor: shape_consistency_manager = ShapeConsistencyManager() global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {}) with torch.no_grad(): - global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec, - global_sharding_spec) + global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime( + distributed_tensor, sharding_spec, global_sharding_spec + ) return global_tensor @@ -43,7 +44,6 @@ def set_shape_consistency_options(options: ShapeConsistencyOptions): class ShapeConsistencyManager(metaclass=SingletonMeta): - def __init__(self): self._options = None self._forward_only = False @@ -69,9 +69,10 @@ def forward_only(self, value): assert isinstance(value, bool) self._forward_only = value - def get_all_all_gather_spec(self, source_spec: ShardingSpec, - orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: - ''' + def get_all_all_gather_spec( + self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float] + ) -> Dict[ShardingSpec, float]: + """ Get all valid sharding specs from source_spec with single all-gather operation, and accumulate communication cost on origin cost which will finally be used in auto sharding solver. For the all-gather operation, we just care about the S dimension. @@ -99,7 +100,7 @@ def get_all_all_gather_spec(self, source_spec: ShardingSpec, device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,R device_mesh_shape: (4, 4): 0} - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD for target_pair in source_spec.dim_partition_dict.items(): @@ -121,19 +122,20 @@ def get_all_all_gather_spec(self, source_spec: ShardingSpec, comm_pattern, sharding_spec=source_spec, gather_dim=gather_dim, - # shard_dim will be used during backward + # shard_dim will be used during backward shard_dim=gather_dim, logical_process_axis=logical_process_axis, - forward_only=self.forward_only) + forward_only=self.forward_only, + ) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -141,9 +143,10 @@ def get_all_all_gather_spec(self, source_spec: ShardingSpec, pass return valid_spec_dict - def get_all_all_to_all_spec(self, source_spec: ShardingSpec, - orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: - ''' + def get_all_all_to_all_spec( + self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float] + ) -> Dict[ShardingSpec, float]: + """ Get all valid sharding specs from source_spec with single all-to-all operation, and accumulate communication cost on origin cost which will finally be used in auto sharding solver. For the all-to-all operation, we just care about the pairs containing S dimension. @@ -173,7 +176,7 @@ def get_all_all_to_all_spec(self, source_spec: ShardingSpec, device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,S1 device_mesh_shape: (4, 4): 0} - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD tensor_dims = len(source_spec.entire_shape) @@ -214,12 +217,14 @@ def get_all_all_to_all_spec(self, source_spec: ShardingSpec, gather_dim = b_index shard_dim = f_index logical_process_axis = b_target_pair[1][-1] - comm_spec = CommSpec(comm_pattern, - sharding_spec=source_spec, - gather_dim=gather_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis, - forward_only=self.forward_only) + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only, + ) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() @@ -238,9 +243,9 @@ def get_all_all_to_all_spec(self, source_spec: ShardingSpec, # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -250,7 +255,7 @@ def get_all_all_to_all_spec(self, source_spec: ShardingSpec, return valid_spec_dict def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): - ''' + """ Get all valid sharding specs from source_spec with single shard operation, and accumulate communication cost on origin cost which will finally be used in auto sharding solver. For the sharding operation, we just care about legal sharding dimensions. @@ -280,7 +285,7 @@ def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,S1 device_mesh_shape: (4, 4): 0} - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD @@ -308,21 +313,23 @@ def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec shard_dim = index logical_process_axis = shard_list[-1] - comm_spec = CommSpec(comm_pattern, - sharding_spec=source_spec, - gather_dim=shard_dim, - shard_dim=shard_dim, - logical_process_axis=logical_process_axis, - forward_only=self.forward_only) + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=shard_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis, + forward_only=self.forward_only, + ) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -330,14 +337,15 @@ def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): pass return valid_spec_dict - def get_all_mix_gather_spec(self, source_spec: ShardingSpec, - orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: - ''' + def get_all_mix_gather_spec( + self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float] + ) -> Dict[ShardingSpec, float]: + """ S0S1 -> RR S1S0 -> RR S01R -> RR RS01 -> RR - ''' + """ valid_spec_dict = {} comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD tensor_dims = len(source_spec.entire_shape) @@ -362,19 +370,21 @@ def get_all_mix_gather_spec(self, source_spec: ShardingSpec, b_target_pair = (b_index, []) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - comm_spec = CommSpec(comm_pattern, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=self.forward_only, - mix_gather=True) + comm_spec = CommSpec( + comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=self.forward_only, + mix_gather=True, + ) cost_dict = comm_spec.get_comm_cost() new_dim_partition_dict = {} # generate new sharding spec try: - new_sharding_spec = ShardingSpec(source_spec.device_mesh, - source_spec.entire_shape, - dim_partition_dict=new_dim_partition_dict) + new_sharding_spec = ShardingSpec( + source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict + ) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) @@ -384,7 +394,7 @@ def get_all_mix_gather_spec(self, source_spec: ShardingSpec, return valid_spec_dict def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]: - ''' + """ Get all valid sharding specs from source_spec with one step transform, and accumulate communication cost on origin cost which will finally be used in auto sharding solver. Note: @@ -398,7 +408,7 @@ def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_d Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. - ''' + """ valid_spec_dict = {} valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict)) valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict)) @@ -545,18 +555,22 @@ def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)): # the first forward comm action will not discard input fwd_action, comm_spec = action_spec_pair - fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, - fwd_peak_numel) if idx == 0 else fwd_action( - comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + fwd_alloc_numel, fwd_peak_numel = ( + fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel) + if idx == 0 + else fwd_action(comm_spec, True, fwd_alloc_numel, fwd_peak_numel) + ) # analyze memory footprint for backward comm actions sequence bwd_alloc_numel = 0 bwd_peak_numel = 0 for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): bwd_action, comm_spec = action_spec_pair - bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel, - bwd_peak_numel) if idx == 0 else bwd_action( - comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + bwd_alloc_numel, bwd_peak_numel = ( + bwd_action(comm_spec, False, bwd_alloc_numel, bwd_peak_numel) + if idx == 0 + else bwd_action(comm_spec, True, bwd_alloc_numel, bwd_peak_numel) + ) fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel) @@ -564,9 +578,10 @@ def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int return TrainCycleItem(fwd_mem, bwd_mem, total_mem) - def shape_consistency(self, source_spec: ShardingSpec, - target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]: - ''' + def shape_consistency( + self, source_spec: ShardingSpec, target_spec: ShardingSpec + ) -> Tuple[List[ShardingSpec], List[CommSpec], float]: + """ This method will find a path to transform source_spec to target_spec with a greedy algorithm. The basic idea is: @@ -623,9 +638,9 @@ def shape_consistency(self, source_spec: ShardingSpec, CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0), CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)] total_cost: 12294.402000000002 - ''' + """ MAX_TRANSFORM_STEPS = 20 - total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0} + total_cost_dict = {"forward": 0, "backward": 0, "total": 0} total_steps = 0 transform_path = [] comm_action_sequence = [] @@ -672,7 +687,7 @@ def shape_consistency(self, source_spec: ShardingSpec, raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor: - ''' + """ Apply target_spec to tensor with source sharding spec, the transform path is generated by the shape_consistency method. @@ -729,7 +744,7 @@ def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSp [1.], [3.], [3.]]) - ''' + """ _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec) for comm_spec in comm_action_sequence: tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec) diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index e594fd297dc4..b78ef6d97dd4 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -8,16 +8,16 @@ from .utils import merge_same_dim_mesh_list -__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec'] +__all__ = ["_DimSpec", "ShardingException", "ShardingSpec"] ALLGATHER_COST = 20 SHARD_COST = 5 STEP_PENALTY = 6 -NAN = 'nan' +NAN = "nan" class _DimSpec: - ''' + """ Sharding spec for single dimension of the sharded tensor describe the sharding dimension of logical device mesh and give a method to compute the difference between them. This class is used internally in ShardingSpec. @@ -25,7 +25,7 @@ class _DimSpec: Argument: shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. Otherwise, the element in shard_list means the data will be sharded in that dimension. - ''' + """ def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 @@ -37,41 +37,40 @@ def __eq__(self, other): def __repr__(self): if self.is_replica: - return 'R' - target = 'S' + return "R" + target = "S" for dim in self.shard_list: target += str(dim) return target def _convert_str_to_shard_list(self, str_spec): - ''' + """ Convert str_spec into shard_list. Argument: str_spec(str): dim spec in str type. - ''' + """ - if str_spec == 'R': + if str_spec == "R": return [] - if str_spec == 'S0': + if str_spec == "S0": return [0] - if str_spec == 'S1': + if str_spec == "S1": return [1] - if str_spec == 'S01': + if str_spec == "S01": return [0, 1] def build_difference_2d_dict(self): - ''' + """ Build a difference mapping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. - ''' + """ - source_spec_list = ['R', 'S0', 'S1', 'S01'] - target_spec_list = ['R', 'S0', 'S1', 'S01'] + source_spec_list = ["R", "S0", "S1", "S01"] + target_spec_list = ["R", "S0", "S1", "S01"] difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - legal_sharding_dims = [] spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) source_shard_list = self._convert_str_to_shard_list(source_spec) target_shard_list = self._convert_str_to_shard_list(target_spec) @@ -81,14 +80,17 @@ def build_difference_2d_dict(self): difference = 0 # all_gather(source) -> target - elif len(source_shard_list - ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list + ): difference = ALLGATHER_COST # shard(source) -> target - elif len(source_shard_list) == len( - target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ - -1] not in source_shard_list: + elif ( + len(source_shard_list) == len(target_shard_list) - 1 + and source_shard_list == target_shard_list[:-1] + and target_shard_list[-1] not in source_shard_list + ): difference = SHARD_COST # S1 -> S0 or S0 -> S1 @@ -119,7 +121,7 @@ def build_difference_2d_dict(self): self.difference_dict = difference_dict def difference(self, other): - ''' + """ The difference between two _DimSpec. Argument: @@ -135,7 +137,7 @@ def difference(self, other): Output: 5 - ''' + """ difference = self.difference_dict[(str(self), str(other))] return difference @@ -157,7 +159,7 @@ class ShardingNotDivisibleError(ShardingSpecException): class ShardingSpec: - ''' + """ Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong to, the entire shape of the tensor before sharded, and the sharding sequence looks like [R, R, S0, S1]. @@ -168,13 +170,11 @@ class ShardingSpec: dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, and the value of the key describe which logical axis will be sharded in that dimension. sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. - ''' + """ - def __init__(self, - device_mesh: DeviceMesh, - entire_shape: torch.Size, - dim_partition_dict=None, - sharding_sequence=None): + def __init__( + self, device_mesh: DeviceMesh, entire_shape: torch.Size, dim_partition_dict=None, sharding_sequence=None + ): self.device_mesh = device_mesh if isinstance(entire_shape, (list, tuple)): @@ -183,12 +183,17 @@ def __init__(self, self.dim_partition_dict = dim_partition_dict self.sharding_sequence = sharding_sequence if self.sharding_sequence is None: - assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' - self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape), - dim_partition_dict=self.dim_partition_dict) + assert ( + self.dim_partition_dict is not None + ), f"dim_partition_dict should not be None, if sharding_sequence is NoneType object." + self.dim_partition_dict = merge_same_dim_mesh_list( + dim_size=len(entire_shape), dim_partition_dict=self.dim_partition_dict + ) self.convert_dict_to_shard_sequence() elif self.dim_partition_dict is None: - assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' + assert ( + self.sharding_sequence is not None + ), f"sharding_sequence should not be None, if dim_partition_dict is NoneType object." self.convert_shard_sequence_to_dict() self._sanity_check() @@ -196,7 +201,7 @@ def __repr__(self): res_list = ["DistSpec:"] res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}") - return ' '.join(res_list) + return " ".join(res_list) def _sanity_check(self): # make sure all axes in logical device mesh only be used once @@ -207,7 +212,8 @@ def _sanity_check(self): dim_check_list.remove(element) else: raise DuplicatedShardingDimensionError( - f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}." + ) # make sure that the dimension is not out of index for dim in self.dim_partition_dict.keys(): @@ -226,22 +232,22 @@ def _sanity_check(self): if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( - f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices." ) def convert_dict_to_shard_sequence(self): - ''' + """ Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence. - ''' + """ sharding_sequence = [_DimSpec([])] * len(self.entire_shape) for dim, shard_list in self.dim_partition_dict.items(): sharding_sequence[dim] = _DimSpec(shard_list) self.sharding_sequence = sharding_sequence def convert_shard_sequence_to_dict(self): - ''' + """ Convert sharding_sequence into dim_partition_dict. - ''' + """ new_dim_partition_dict = {} for index, dim_spec in enumerate(self.sharding_sequence): if not dim_spec.is_replica: @@ -251,7 +257,7 @@ def convert_shard_sequence_to_dict(self): self.dim_partition_dict = new_dim_partition_dict def sharding_sequence_difference(self, other): - ''' + """ This function is a naive version of difference computation. It just simply accumulates difference every dimension between the pair of sharding sequence. @@ -276,21 +282,22 @@ def sharding_sequence_difference(self, other): Return: difference(int): Difference between two ShardingSpec. - ''' + """ assert len(self.sharding_sequence) == len( - other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + other.sharding_sequence + ), f"Cannot compare difference for two sharding specs with different length." difference = 0 for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): difference += orig_dim_spec.difference(other_dim_spec) return difference def get_sharded_shape_per_device(self): - sharded_shape = list(self.entire_shape) for dim, shard_list in self.dim_partition_dict.items(): mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) - assert sharded_shape[ - dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + assert ( + sharded_shape[dim] % shard_partitions == 0 + ), f"Cannot shard dimension {dim} into {shard_partitions} partitions." sharded_shape[dim] //= shard_partitions return torch.Size(sharded_shape) diff --git a/colossalai/tensor/utils.py b/colossalai/tensor/utils.py index e7d51d099e02..19dde8febf84 100644 --- a/colossalai/tensor/utils.py +++ b/colossalai/tensor/utils.py @@ -7,7 +7,7 @@ def all_gather_simulator(target_pair): - ''' + """ Simulating all-gather operation, analyze the communication cost and simulate the influence of the DimSpec. @@ -19,7 +19,7 @@ def all_gather_simulator(target_pair): Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, and the second element describes which logical axis will be sharded in that dimension. - ''' + """ _, shard_list = target_pair new_shard_list = shard_list[:-1] @@ -27,7 +27,7 @@ def all_gather_simulator(target_pair): def all_to_all_simulator(f_target_pair, b_target_pair): - ''' + """ Simulating all-to-all operation, analyze the communication cost and simulate the influence of the DimSpec. @@ -47,7 +47,7 @@ def all_to_all_simulator(f_target_pair, b_target_pair): Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, and the second element describes which logical axis will be sharded in that dimension. - ''' + """ _, f_shard_list = f_target_pair _, b_shard_list = b_target_pair if not len(b_shard_list): @@ -61,7 +61,7 @@ def all_to_all_simulator(f_target_pair, b_target_pair): def shard_simulator(target_pair, legal_sharding_dims): - ''' + """ Simulating shard operation, analyze the communication cost(always ZERO) and simulate the influence of the DimSpec. @@ -78,7 +78,7 @@ def shard_simulator(target_pair, legal_sharding_dims): Argument: target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, and the second element describes which logical axis will be sharded in that dimension. - ''' + """ _, shard_list = target_pair shard_list_list = [] for dim in legal_sharding_dims: @@ -91,7 +91,7 @@ def shard_simulator(target_pair, legal_sharding_dims): def mix_gather_simulator(f_target_pair, b_target_pair): - ''' + """ Assume index of f and b target pairs are 'f' and 'b' S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0) S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1) @@ -99,7 +99,7 @@ def mix_gather_simulator(f_target_pair, b_target_pair): RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1) S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0) RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0) - ''' + """ if f_target_pair[1] and b_target_pair[1]: leading_dim = b_target_pair[1] > f_target_pair[1] return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)] @@ -118,7 +118,7 @@ def mix_gather_simulator(f_target_pair, b_target_pair): # The function is credited to PyTorch Team def named_params_with_colotensor( module: nn.Module, - prefix: str = '', + prefix: str = "", recurse: bool = True, ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]: r"""Returns an iterator over module parameters (together with the @@ -154,7 +154,7 @@ def named_params_with_colotensor( for name, val in vars(mod).items(): if isinstance(val, ColoTensor) and val not in memo: memo.add(val) - name = mod_prefix + ('.' if mod_prefix else '') + name + name = mod_prefix + ("." if mod_prefix else "") + name yield name, val # find all nn.Parameters @@ -169,15 +169,16 @@ def _convert_tensor(tensor: torch.Tensor) -> ColoTensor: def convert_parameter(module: torch.nn.Module, param_name: str): # Perform some validation first. if not hasattr(module, param_name): - raise ValueError(f'module: {module} does not have parameter with name: {param_name}') + raise ValueError(f"module: {module} does not have parameter with name: {param_name}") tensor = getattr(module, param_name) if not isinstance(tensor, torch.Tensor): raise ValueError( - f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}') + f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}" + ) if not tensor.is_contiguous(): - raise ValueError(f'param: {param_name} is not a contiguous Tensor') + raise ValueError(f"param: {param_name} is not a contiguous Tensor") st = _convert_tensor(tensor) @@ -193,9 +194,9 @@ def convert_parameter(module: torch.nn.Module, param_name: str): def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: - ''' + """ This method is used to convert the negative dim value to positive. - ''' + """ dims_to_convert = [] for dim, mesh_list in dim_partition_dict.items(): if dim < 0: @@ -207,13 +208,13 @@ def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]: - ''' + """ This method is used to merge the different key value which points to same physical position. For example: dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position. In this method, above dim_partition_dict will be converted to {1: [0, 1]} - ''' + """ converted_dim_partition_dict = {} for dim, mesh_list in dim_partition_dict.items(): if dim < 0: diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index 0db33361c6a0..c6956e81fbde 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -19,7 +19,19 @@ ) __all__ = [ - 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', - 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn', - 'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal', 'assert_hf_output_close' + "assert_equal", + "assert_not_equal", + "assert_close", + "assert_close_loose", + "assert_equal_in_group", + "parameterize", + "rerun_on_exception", + "rerun_if_address_is_in_use", + "skip_if_not_enough_gpus", + "free_port", + "spawn", + "clear_cache_before_run", + "run_on_environment_flag", + "check_state_dict_equal", + "assert_hf_output_close", ] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 8d9ec8ab5f35..816bc0d7b6d7 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -9,20 +9,22 @@ def assert_equal(a: Tensor, b: Tensor): - assert torch.all(a == b), f'expected a and b to be equal but they are not, {a} vs {b}' + assert torch.all(a == b), f"expected a and b to be equal but they are not, {a} vs {b}" def assert_not_equal(a: Tensor, b: Tensor): - assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}' + assert not torch.all(a == b), f"expected a and b to be not equal but they are, {a} vs {b}" def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3): - assert_close(a, - b, - rtol=rtol, - atol=atol, - msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ - dtype: {a.dtype} vs {b.dtype}") + assert_close( + a, + b, + rtol=rtol, + atol=atol, + msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ + dtype: {a.dtype} vs {b.dtype}", + ) def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): @@ -35,12 +37,13 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): for i in range(world_size - 1): a = tensor_list[i] b = tensor_list[i + 1] - assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}' + assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}" def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): - assert len(list(d1.keys())) == len(list(d2.keys())), \ - f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}" + assert len(list(d1.keys())) == len( + list(d2.keys()) + ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}" for k, v1 in d1.items(): assert k in d2 v2 = d2[k] @@ -86,12 +89,9 @@ def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_devic assert v1 == v2, f"{v1} not equals to {v2}" -def assert_hf_output_close(out1: Any, - out2: Any, - ignore_keys: List[str] = None, - track_name: str = "", - atol=1e-5, - rtol=1e-5): +def assert_hf_output_close( + out1: Any, out2: Any, ignore_keys: List[str] = None, track_name: str = "", atol=1e-5, rtol=1e-5 +): """ Check if two outputs from huggingface are equal. @@ -108,23 +108,17 @@ def assert_hf_output_close(out1: Any, for k in out1.keys(): if ignore_keys is not None and k in ignore_keys: continue - assert_hf_output_close(out1[k], - out2[k], - track_name=f"{track_name}.{k}", - ignore_keys=ignore_keys, - atol=atol, - rtol=rtol) + assert_hf_output_close( + out1[k], out2[k], track_name=f"{track_name}.{k}", ignore_keys=ignore_keys, atol=atol, rtol=rtol + ) elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)): # if two values are list # we recursively check the elements assert len(out1) == len(out2) for i in range(len(out1)): - assert_hf_output_close(out1[i], - out2[i], - track_name=f"{track_name}.{i}", - ignore_keys=ignore_keys, - atol=atol, - rtol=rtol) + assert_hf_output_close( + out1[i], out2[i], track_name=f"{track_name}.{i}", ignore_keys=ignore_keys, atol=atol, rtol=rtol + ) elif isinstance(out1, Tensor) and isinstance(out2, Tensor): if out1.shape != out2.shape: raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}") diff --git a/colossalai/testing/pytest_wrapper.py b/colossalai/testing/pytest_wrapper.py index 6a80e1dcc548..b1e82b469c96 100644 --- a/colossalai/testing/pytest_wrapper.py +++ b/colossalai/testing/pytest_wrapper.py @@ -33,13 +33,14 @@ def test_for_something(): import pytest except ImportError: raise ImportError( - 'This function requires `pytest` to be installed, please do `pip install pytest` and try again.') + "This function requires `pytest` to be installed, please do `pip install pytest` and try again." + ) assert isinstance(name, str) - flag = os.environ.get(name.upper(), '0') + flag = os.environ.get(name.upper(), "0") - reason = f'Environment variable {name} is {flag}' - if flag == '1': + reason = f"Environment variable {name} is {flag}" + if flag == "1": return pytest.mark.skipif(False, reason=reason) else: return pytest.mark.skipif(True, reason=reason) diff --git a/colossalai/testing/random.py b/colossalai/testing/random.py index ad6d24a4b94b..4525dff3fe80 100644 --- a/colossalai/testing/random.py +++ b/colossalai/testing/random.py @@ -11,7 +11,7 @@ def seed_all(seed, cuda_deterministic=False): if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - if cuda_deterministic: # slower, more reproducible + if cuda_deterministic: # slower, more reproducible torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False else: diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index a4370a8d4933..fdbda9a598bf 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -55,7 +55,6 @@ def say_something(person, msg): """ def _wrapper(func): - def _execute_function_by_param(**kwargs): for val in values: arg_map = {argument: val} @@ -120,11 +119,11 @@ def _match_lines(lines, pattern): return False def _wrapper(func): - def _run_until_success(*args, **kwargs): try_count = 0 - assert max_try is None or isinstance(max_try, int), \ - f'Expected max_try to be None or int, but got {type(max_try)}' + assert max_try is None or isinstance( + max_try, int + ), f"Expected max_try to be None or int, but got {type(max_try)}" while max_try is None or try_count < max_try: try: @@ -132,14 +131,14 @@ def _run_until_success(*args, **kwargs): ret = func(*args, **kwargs) return ret except exception_type as e: - error_lines = str(e).split('\n') + error_lines = str(e).split("\n") if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)): - print('Exception is caught, retrying...') + print("Exception is caught, retrying...") # when pattern is not specified, we always skip the exception # when pattern is specified, we only skip when pattern is matched continue else: - print('Maximum number of attempts is reached or pattern is not matched, no more retrying...') + print("Maximum number of attempts is reached or pattern is not matched, no more retrying...") raise e # Override signature @@ -198,7 +197,6 @@ def test_something(): """ def _wrap_func(f): - def _execute_by_gpu_num(*args, **kwargs): num_avail_gpu = torch.cuda.device_count() if num_avail_gpu >= min_gpus: @@ -263,7 +261,6 @@ def test_something(): """ def _wrap_func(f): - def _clear_cache(*args, **kwargs): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 5226f688b43b..3ec39b949a23 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -13,20 +13,20 @@ from .timer import MultiTimer, Timer __all__ = [ - 'conditional_context', - 'get_current_device', - 'synchronize', - 'empty_cache', - 'set_to_cuda', - 'Timer', - 'MultiTimer', - 'multi_tensor_applier', - 'TensorDetector', - 'ensure_path_exists', - 'disposable', - '_cast_float', - 'free_storage', - 'set_seed', - 'is_ddp_ignored', - 'set_device', + "conditional_context", + "get_current_device", + "synchronize", + "empty_cache", + "set_to_cuda", + "Timer", + "MultiTimer", + "multi_tensor_applier", + "TensorDetector", + "ensure_path_exists", + "disposable", + "_cast_float", + "free_storage", + "set_seed", + "is_ddp_ignored", + "set_device", ] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index 8c769c5b13c0..c43caaff4806 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -28,7 +28,7 @@ def conditional_context(context_manager, enable=True): def is_ddp_ignored(p): - return getattr(p, '_ddp_to_ignore', False) + return getattr(p, "_ddp_to_ignore", False) def disposable(func: Callable) -> Callable: diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py index 6b5d17cf04e7..6bfb08d1f04a 100644 --- a/colossalai/utils/cuda.py +++ b/colossalai/utils/cuda.py @@ -29,9 +29,9 @@ def get_current_device() -> torch.device: If cuda available, return gpu, otherwise return cpu. """ if torch.cuda.is_available(): - return torch.device(f'cuda:{torch.cuda.current_device()}') + return torch.device(f"cuda:{torch.cuda.current_device()}") else: - return torch.device('cpu') + return torch.device("cpu") def synchronize(): diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py index 21bc530934d3..4eee4fbc0eee 100644 --- a/colossalai/utils/model/utils.py +++ b/colossalai/utils/model/utils.py @@ -27,19 +27,18 @@ def call_to_str(base, *args, **kwargs): Returns: str: A string representation of base(*args, **kwargs) """ - name = f'{base}(' + name = f"{base}(" if args: - name += ', '.join(repr(arg) for arg in args) + name += ", ".join(repr(arg) for arg in args) if kwargs: - name += ', ' + name += ", " if kwargs: - name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items()) - name += ')' + name += ", ".join(f"{key}={repr(arg)}" for key, arg in kwargs.items()) + name += ")" return name class InsertPostInitMethodToModuleSubClasses(object): - def __init__(self, default_dtype: Optional[torch.dtype] = None): self._old_default_dtype = None self._default_dtype = default_dtype @@ -53,7 +52,6 @@ def __enter__(self): torch.set_default_dtype(self._default_dtype) def preprocess_after(f): - @functools.wraps(f) def wrapper(module: torch.nn.Module, *args, **kwargs): f(module, *args, **kwargs) @@ -74,7 +72,7 @@ def _init_subclass(cls, **kwargs): substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set()) # holding on to the current __init__subclass__ for exit - torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__) + torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ # Replace .__init__() for future subclasses of torch.nn.Module torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) @@ -82,12 +80,11 @@ def _init_subclass(cls, **kwargs): return self def __exit__(self, exc_type, exc_value, traceback): - if self._default_dtype is not None: torch.set_default_dtype(self._old_default_dtype) def _disable_class(cls): - if not hasattr(cls, '_old_init'): + if not hasattr(cls, "_old_init"): raise AttributeError( f"_old_init is not found in the {cls.__name__}, please make sure that you have imported {cls.__name__} before entering the context." ) @@ -97,7 +94,7 @@ def _disable_class(cls): substitute_init_recursively(torch.nn.modules.module.Module, _disable_class, set()) # Replace .__init__() for future subclasses of torch.nn.Module - torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass) + torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass self._post_context_exec() # Now that we cleaned up the metaclass injection, raise the exception. diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py index 6456dfb905b0..1b75448bdd3c 100644 --- a/colossalai/utils/moe.py +++ b/colossalai/utils/moe.py @@ -19,8 +19,8 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]] """ epsize_param_dict = dict() for param in model.parameters(): - if not hasattr(param, 'moe_info'): - ep_size = 1 # set ep_size to 1 for dp parameters + if not hasattr(param, "moe_info"): + ep_size = 1 # set ep_size to 1 for dp parameters else: ep_size = param.moe_info.ep_size if ep_size not in epsize_param_dict: @@ -37,7 +37,6 @@ def sync_moe_model_param(model: nn.Module): model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. """ if is_using_ddp(): - param_dict = get_moe_epsize_param_dict(model) # synchronize the parameters whose dp_group is the whole world diff --git a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py index 2b6de5fe1f3c..750c2a32da34 100644 --- a/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py +++ b/colossalai/utils/multi_tensor_apply/multi_tensor_apply.py @@ -25,7 +25,9 @@ def check_avail(self): raise RuntimeError( "Attempted to call MultiTensorApply method, but MultiTensorApply " "is not available, possibly because Apex was installed without " - "--cpp_ext --cuda_ext. Original import error message:", MultiTensorApply.import_err) + "--cpp_ext --cuda_ext. Original import error message:", + MultiTensorApply.import_err, + ) def __call__(self, op, noop_flag_buffer, tensor_lists, *args): self.check_avail() diff --git a/colossalai/utils/rank_recorder/README.md b/colossalai/utils/rank_recorder/README.md index da8a6039d543..cad6c1fddd71 100644 --- a/colossalai/utils/rank_recorder/README.md +++ b/colossalai/utils/rank_recorder/README.md @@ -1,7 +1,7 @@ # Rank Recorder This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualize the json file easily. -Before using the tool, you should ensure dist.is_initialized() return true before exit of program. +Before using the tool, you should ensure dist.is_initialized() return true before exit of program. ## Usage @@ -58,10 +58,10 @@ def worker(rank): with recorder("calc_1(x100)", rank) as r: calc(100, 100) - + with recorder("calc_2(x400)", rank) as r: calc(400, 400) - + with recorder("calc_2(x200)", rank) as r: calc(200, 200) @@ -69,4 +69,4 @@ if __name__ == "__main__": mp.spawn(worker, nprocs=WORLD_SIZE) ``` -run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder. \ No newline at end of file +run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder. diff --git a/colossalai/utils/rank_recorder/__init__.py b/colossalai/utils/rank_recorder/__init__.py index 1274d0e7dbc5..1d347075a8ce 100644 --- a/colossalai/utils/rank_recorder/__init__.py +++ b/colossalai/utils/rank_recorder/__init__.py @@ -1,3 +1,3 @@ from colossalai.utils.rank_recorder.rank_recorder import recorder -__all__ = ["recorder"] \ No newline at end of file +__all__ = ["recorder"] diff --git a/colossalai/utils/rank_recorder/rank_recorder.py b/colossalai/utils/rank_recorder/rank_recorder.py index 40bb7e184a12..1cb9169125a1 100644 --- a/colossalai/utils/rank_recorder/rank_recorder.py +++ b/colossalai/utils/rank_recorder/rank_recorder.py @@ -1,18 +1,15 @@ -import time -from typing import List, Dict +import atexit import json import os -import time import shutil -import atexit +import time +from typing import Dict, List +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt import torch import torch.distributed as dist -import json -import matplotlib.pyplot as plt -import matplotlib.colors as mcolors - cmap = list(mcolors.TABLEAU_COLORS.values()) LOG_FOLDER = "record.log" @@ -20,7 +17,6 @@ class Event: - def __init__(self, start: int, end: int, name: str, rank: int) -> None: self.start = start self.end = end @@ -29,16 +25,15 @@ def __init__(self, start: int, end: int, name: str, rank: int) -> None: class Recorder: - def __init__(self) -> None: self.rank_to_history: Dict[int, List[Event]] = {} self.base_time = time.time() self.temp_event = None - self.export_format = 'png' - self.export_name = 'test' + self.export_format = "png" + self.export_name = "test" self.dpi = 500 - self.theme = 'dark_background' + self.theme = "dark_background" self.figure_width = 30 self.figure_height = 10 self.legend_fontsize = 16 @@ -84,18 +79,18 @@ def __exit__(self, *args): def dump_record(self): rank = dist.get_rank() rank_to_history = self.rank_to_history - records = {'base_time': self.base_time, 'content': {}} + records = {"base_time": self.base_time, "content": {}} for record_rank in rank_to_history: history = rank_to_history[record_rank] recs = [] for event in history: - rec = {'start': event.start, 'end': event.end, 'name': event.name} + rec = {"start": event.start, "end": event.end, "name": event.name} recs.append(rec) - records['content'][record_rank] = recs + records["content"][record_rank] = recs - dump_name = f'{rank}.json' + dump_name = f"{rank}.json" dump_path = os.path.join(LOG_FOLDER, dump_name) - with open(dump_path, 'w', encoding='utf-8') as f: + with open(dump_path, "w", encoding="utf-8") as f: json.dump(records, f, ensure_ascii=False) def merge_recode(self): @@ -117,24 +112,22 @@ def merge_recode(self): logs_path = [os.path.join(LOG_FOLDER, file) for file in os.listdir(LOG_FOLDER)] recoders = {} for path in logs_path: - with open(path, 'r', encoding='utf-8') as f: + with open(path, "r", encoding="utf-8") as f: recs = json.load(f) - for record_rank in recs['content']: - history = recs['content'][record_rank] + for record_rank in recs["content"]: + history = recs["content"][record_rank] recoders[record_rank] = [] for rec in history: - recoders[record_rank].append({ - 'start': rec['start'] - base_time, - 'end': rec['end'] - base_time, - 'name': rec['name'] - }) + recoders[record_rank].append( + {"start": rec["start"] - base_time, "end": rec["end"] - base_time, "name": rec["name"]} + ) shutil.rmtree(LOG_FOLDER) - with open(self.export_name + '.json', 'w', encoding='utf-8') as f: + with open(self.export_name + ".json", "w", encoding="utf-8") as f: json.dump(recoders, f, ensure_ascii=False) def visualize_record(self): - with open(self.export_name + '.json', 'r', encoding='utf-8') as f: + with open(self.export_name + ".json", "r", encoding="utf-8") as f: records = json.load(f) records = dict(records) ranks = list(sorted(records.keys())) @@ -147,9 +140,9 @@ def visualize_record(self): for rank in ranks: rank_records = records[rank] for rec in rank_records: - s = rec['start'] - e = rec['end'] - name = rec['name'] + s = rec["start"] + e = rec["end"] + name = rec["name"] if name not in name_list: name_list[name] = len(name_list) bar = plt.barh(rank, width=e - s, height=self.bar_height, left=s, color=cmap[name_list[name]]) @@ -157,8 +150,8 @@ def visualize_record(self): plots[name] = bar plt.legend(list(plots.values()), list(plots.keys()), loc="upper left", fontsize=self.legend_fontsize) - plt.yticks(ticks=ranks, labels=[f'Device:{rank}' for rank in ranks], fontsize=self.device_fontsize) - plt.grid(axis='x') + plt.yticks(ticks=ranks, labels=[f"Device:{rank}" for rank in ranks], fontsize=self.device_fontsize) + plt.grid(axis="x") plt.savefig("{}.{}".format(self.export_name, self.export_format)) def exit_worker(self): diff --git a/colossalai/utils/tensor_detector/__init__.py b/colossalai/utils/tensor_detector/__init__.py index cafc19b67c5c..c6c68aa4009b 100644 --- a/colossalai/utils/tensor_detector/__init__.py +++ b/colossalai/utils/tensor_detector/__init__.py @@ -1 +1 @@ -from .tensor_detector import TensorDetector +from .tensor_detector import TensorDetector diff --git a/colossalai/utils/tensor_detector/readme.md b/colossalai/utils/tensor_detector/readme.md index d6852ea55b54..455eae18116a 100644 --- a/colossalai/utils/tensor_detector/readme.md +++ b/colossalai/utils/tensor_detector/readme.md @@ -14,7 +14,7 @@ class MLP(nn.Module): super().__init__() self.mlp = nn.Sequential(nn.Linear(64, 8), nn.ReLU(), - nn.Linear(8, 32)) + nn.Linear(8, 32)) def forward(self, x): return self.mlp(x) ``` @@ -125,4 +125,3 @@ Total GPU Memory Allocated on cuda:0 is 14.0 KB This tool was inspired by https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py and https://github.com/Oldpan/Pytorch-Memory-Utils - diff --git a/colossalai/utils/tensor_detector/tensor_detector.py b/colossalai/utils/tensor_detector/tensor_detector.py index cfcd4e47b4cb..38cf094b8dd0 100644 --- a/colossalai/utils/tensor_detector/tensor_detector.py +++ b/colossalai/utils/tensor_detector/tensor_detector.py @@ -1,21 +1,19 @@ import gc import inspect +from collections import defaultdict +from typing import Optional + import torch import torch.nn as nn -from typing import Optional -from collections import defaultdict LINE_WIDTH = 108 -LINE = '-' * LINE_WIDTH + '\n' - +LINE = "-" * LINE_WIDTH + "\n" -class TensorDetector(): - def __init__(self, - show_info: bool = True, - log: str = None, - include_cpu: bool = False, - module: Optional[nn.Module] = None): +class TensorDetector: + def __init__( + self, show_info: bool = True, log: str = None, include_cpu: bool = False, module: Optional[nn.Module] = None + ): """This class is a detector to detect tensor on different devices. Args: @@ -57,40 +55,39 @@ def get_tensor_mem(self, tensor): def mem_format(self, real_memory_size): # format the tensor memory into a reasonable magnitude if real_memory_size >= 2**30: - return str(real_memory_size / (2**30)) + ' GB' + return str(real_memory_size / (2**30)) + " GB" if real_memory_size >= 2**20: - return str(real_memory_size / (2**20)) + ' MB' + return str(real_memory_size / (2**20)) + " MB" if real_memory_size >= 2**10: - return str(real_memory_size / (2**10)) + ' KB' - return str(real_memory_size) + ' B' + return str(real_memory_size / (2**10)) + " KB" + return str(real_memory_size) + " B" def collect_tensors_state(self): for obj in gc.get_objects(): if torch.is_tensor(obj): # skip cpu tensor when include_cpu is false and the tensor we have collected before - if (not self.include_cpu) and obj.device == torch.device('cpu'): + if (not self.include_cpu) and obj.device == torch.device("cpu"): continue self.detected.append(id(obj)) # skip parameters we had added in __init__ when module is an instance of nn.Module for the first epoch if id(obj) not in self.tensor_info: - name = type(obj).__name__ # after backward, we want to update the records, to show you the change - if isinstance(self.module, nn.Module) and name == 'Parameter': + if isinstance(self.module, nn.Module) and name == "Parameter": if obj.grad is not None: # with grad attached for par_name, param in self.module.named_parameters(): if param.requires_grad and param.grad.equal(obj.grad): - name = par_name + ' (with grad)' + name = par_name + " (with grad)" else: # with no grad attached # there will be no new parameters created during running # so it must be in saved_tensor_info continue # we can also marked common tensors as tensor(with grad) - if name == 'Tensor' and (obj.is_leaf or obj.retains_grad): + if name == "Tensor" and (obj.is_leaf or obj.retains_grad): if obj.grad is not None: - name = name + ' (with grad)' + name = name + " (with grad)" # in fact, common tensor have no grad # unless you set retain_grad() if id(obj) in self.saved_tensor_info.keys() and name == self.saved_tensor_info[id(obj)][0]: @@ -111,10 +108,10 @@ def collect_tensors_state(self): self.devices.append(obj.device) def print_tensors_state(self): - template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}' + template_format = "{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}" self.info += LINE - self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem') - self.info += '\n' + self.info += template_format.format(" ", "Tensor", "device", "shape", "grad", "dtype", "Mem") + self.info += "\n" self.info += LINE # if a tensor updates this turn, and was recorded before @@ -124,24 +121,30 @@ def print_tensors_state(self): minus = outdated + minus if len(self.order) > 0: for tensor_id in self.order: - self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]), - str(self.tensor_info[tensor_id][1]), - str(tuple(self.tensor_info[tensor_id][2])), - str(self.tensor_info[tensor_id][3]), - str(self.tensor_info[tensor_id][4]), - str(self.tensor_info[tensor_id][5])) - self.info += '\n' + self.info += template_format.format( + "+", + str(self.tensor_info[tensor_id][0]), + str(self.tensor_info[tensor_id][1]), + str(tuple(self.tensor_info[tensor_id][2])), + str(self.tensor_info[tensor_id][3]), + str(self.tensor_info[tensor_id][4]), + str(self.tensor_info[tensor_id][5]), + ) + self.info += "\n" if len(self.order) > 0 and len(minus) > 0: - self.info += '\n' + self.info += "\n" if len(minus) > 0: for tensor_id in minus: - self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]), - str(self.saved_tensor_info[tensor_id][1]), - str(tuple(self.saved_tensor_info[tensor_id][2])), - str(self.saved_tensor_info[tensor_id][3]), - str(self.saved_tensor_info[tensor_id][4]), - str(self.saved_tensor_info[tensor_id][5])) - self.info += '\n' + self.info += template_format.format( + "-", + str(self.saved_tensor_info[tensor_id][0]), + str(self.saved_tensor_info[tensor_id][1]), + str(tuple(self.saved_tensor_info[tensor_id][2])), + str(self.saved_tensor_info[tensor_id][3]), + str(self.saved_tensor_info[tensor_id][4]), + str(self.saved_tensor_info[tensor_id][5]), + ) + self.info += "\n" # deleted the updated tensor self.saved_tensor_info.pop(tensor_id) @@ -152,16 +155,16 @@ def print_tensors_state(self): self.info += LINE self.info += f"Detect Location: {locate_msg}\n" for device in self.devices: - if device == torch.device('cpu'): + if device == torch.device("cpu"): continue gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device)) self.info += f"Total GPU Memory Allocated on {device} is {gpu_mem_alloc}\n" self.info += LINE - self.info += '\n\n' + self.info += "\n\n" if self.show_info: print(self.info) if self.log is not None: - with open(self.log + '.log', 'a') as f: + with open(self.log + ".log", "a") as f: f.write(self.info) def detect(self, include_cpu=False): diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 4b61f4a5ef11..2f61817f0461 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -2,12 +2,12 @@ # -*- encoding: utf-8 -*- import time from typing import Tuple + from .cuda import synchronize class Timer: - """A timer object which helps to log the execution times, and provides different tools to assess the times. - """ + """A timer object which helps to log the execution times, and provides different tools to assess the times.""" def __init__(self): self._started = False @@ -25,16 +25,14 @@ def current_time(self) -> float: return time.time() def start(self): - """Firstly synchronize cuda, reset the clock and then start the timer. - """ + """Firstly synchronize cuda, reset the clock and then start the timer.""" self._elapsed = 0 synchronize() self._start_time = time.time() self._started = True def lap(self): - """lap time and return elapsed time - """ + """lap time and return elapsed time""" return self.current_time - self._start_time def stop(self, keep_in_history: bool = False): @@ -80,12 +78,11 @@ def get_elapsed_time(self): Note: Use it only when timer is not in progress """ - assert not self._started, 'Timer is still in progress' + assert not self._started, "Timer is still in progress" return self._elapsed def reset(self): - """Clear up the timer and its history - """ + """Clear up the timer and its history""" self._history = [] self._started = False self._elapsed = 0 diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 4991241b8df1..90d0f8de1916 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -10,6 +10,13 @@ from .wrapper import zero_model_wrapper, zero_optim_wrapper __all__ = [ - 'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper', - 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model' + "GeminiDDP", + "GeminiOptimizer", + "GeminiAdamOptimizer", + "zero_model_wrapper", + "zero_optim_wrapper", + "LowLevelZeroOptimizer", + "ColoInitContext", + "post_process_colo_init_ctx", + "get_static_torch_model", ] diff --git a/colossalai/zero/gemini/__init__.py b/colossalai/zero/gemini/__init__.py index 7ac6a9be4140..358d5c7fd289 100644 --- a/colossalai/zero/gemini/__init__.py +++ b/colossalai/zero/gemini/__init__.py @@ -6,6 +6,15 @@ from .utils import get_static_torch_model __all__ = [ - 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP', - 'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx' + "GeminiManager", + "TensorInfo", + "TensorState", + "ChunkManager", + "search_chunk_configuration", + "GeminiDDP", + "get_static_torch_model", + "GeminiAdamOptimizer", + "GeminiOptimizer", + "ColoInitContext", + "post_process_colo_init_ctx", ] diff --git a/colossalai/zero/gemini/chunk/__init__.py b/colossalai/zero/gemini/chunk/__init__.py index 6914d2dbef45..91906f68ad25 100644 --- a/colossalai/zero/gemini/chunk/__init__.py +++ b/colossalai/zero/gemini/chunk/__init__.py @@ -3,4 +3,4 @@ from .search_utils import classify_params_by_dp_degree, search_chunk_configuration from .utils import init_chunk_manager -__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager'] +__all__ = ["Chunk", "ChunkManager", "classify_params_by_dp_degree", "search_chunk_configuration", "init_chunk_manager"] diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 3e7403adb53b..bbef9013c20b 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -17,12 +17,17 @@ class TensorState(Enum): READY_FOR_REDUCE = 4 -STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), - (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), (TensorState.COMPUTE, - TensorState.HOLD), - (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), - (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE, - TensorState.HOLD)) +STATE_TRANS = ( + (TensorState.FREE, TensorState.HOLD), + (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), + (TensorState.HOLD, TensorState.COMPUTE), + (TensorState.COMPUTE, TensorState.HOLD), + (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), + (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), + (TensorState.READY_FOR_REDUCE, TensorState.HOLD), +) @dataclass @@ -53,14 +58,16 @@ def alloc_storage(tensor: torch.Tensor) -> None: class Chunk: _total_number = 0 - def __init__(self, - chunk_size: int, - process_group: ProcessGroup, - dtype: torch.dtype, - init_device: Optional[torch.device] = None, - cpu_shard_init: bool = False, - keep_gathered: bool = False, - pin_memory: bool = False) -> None: + def __init__( + self, + chunk_size: int, + process_group: ProcessGroup, + dtype: torch.dtype, + init_device: Optional[torch.device] = None, + cpu_shard_init: bool = False, + keep_gathered: bool = False, + pin_memory: bool = False, + ) -> None: """ Chunk: A container owning a piece of contiguous memory space for tensors Here we use all-gather operation to gather the whole chunk. @@ -99,9 +106,9 @@ def __init__(self, device = init_device or get_current_device() # chunk_temp is a global chunk, which only exists during building the chunks. - self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero + self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero - self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA + self.cuda_global_chunk = None # we force cuda_global_chunk located in CUDA # cuda local chunk, which is sharded on GPUs self.cuda_shard = None @@ -134,7 +141,7 @@ def __init__(self, # they are treated the same as that of the parameters in DDP during training. self.keep_gathered = keep_gathered if self.keep_gathered: - pin_memory = False # since this chunk is gathered, it doesn't need to pin + pin_memory = False # since this chunk is gathered, it doesn't need to pin # if pin_memory is True, we allocate a piece of CPU pin-memory # for it all the time @@ -160,7 +167,7 @@ def memory_usage(self) -> Dict[str, int]: if self.chunk_temp is not None: # this chunk is not closed - if self.chunk_temp.device.type == 'cuda': + if self.chunk_temp.device.type == "cuda": cuda_memory += self.chunk_mem else: cpu_memory += self.chunk_mem @@ -180,11 +187,11 @@ def device_type(self) -> str: return self.chunk_temp.device.type else: if self.is_gathered: - return 'cuda' + return "cuda" elif self.cuda_shard is not None: - return 'cuda' + return "cuda" else: - return 'cpu' + return "cpu" @property def payload(self) -> torch.Tensor: @@ -217,8 +224,10 @@ def can_release(self) -> bool: if self.keep_gathered: return False else: - return self.tensor_state_cnter[TensorState.HOLD] + \ - self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors + return ( + self.tensor_state_cnter[TensorState.HOLD] + self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] + == self.num_tensors + ) @property def can_reduce(self): @@ -226,27 +235,25 @@ def can_reduce(self): @property def has_inf_or_nan(self) -> bool: - """Check if the chunk has inf or nan values on CUDA. - """ + """Check if the chunk has inf or nan values on CUDA.""" if self.is_gathered: - valid_tensor = self.cuda_global_chunk[:self.utilized_size] + valid_tensor = self.cuda_global_chunk[: self.utilized_size] else: - assert self.cuda_shard is not None # only check on CUDA - valid_tensor = self.cuda_shard[:self.valid_end] + assert self.cuda_shard is not None # only check on CUDA + valid_tensor = self.cuda_shard[: self.valid_end] return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item() def set_l2_norm(self) -> None: - """Record l2 norm of this chunks on CUDA. - """ + """Record l2 norm of this chunks on CUDA.""" assert self.l2_norm is None, "you are calculating the l2 norm twice" if self.is_gathered: - valid_tensor = self.cuda_global_chunk[:self.utilized_size] + valid_tensor = self.cuda_global_chunk[: self.utilized_size] else: - assert self.cuda_shard is not None # calculate on CUDA - valid_tensor = self.cuda_shard[:self.valid_end] + assert self.cuda_shard is not None # calculate on CUDA + valid_tensor = self.cuda_shard[: self.valid_end] chunk_l2_norm = valid_tensor.data.float().norm(2) - self.l2_norm = chunk_l2_norm.item()**2 + self.l2_norm = chunk_l2_norm.item() ** 2 def append_tensor(self, tensor: torch.Tensor): """Add a tensor to the chunk. @@ -263,9 +270,9 @@ def append_tensor(self, tensor: torch.Tensor): if new_utilized_size > self.chunk_size: raise ChunkFullError - self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten()) + self.chunk_temp[self.utilized_size : new_utilized_size].copy_(tensor.data.flatten()) assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor" - tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape) + tensor.data = self.chunk_temp[self.utilized_size : new_utilized_size].view(tensor.shape) # record all the information about the tensor self.num_tensors += 1 @@ -275,8 +282,7 @@ def append_tensor(self, tensor: torch.Tensor): self.utilized_size = new_utilized_size def close_chunk(self): - """Close the chunk. Any tensor can't be appended to a closed chunk later. - """ + """Close the chunk. Any tensor can't be appended to a closed chunk later.""" # sanity check assert self.chunk_temp is not None @@ -286,7 +292,7 @@ def close_chunk(self): elif self.utilized_size < self.shard_end: self.valid_end = self.utilized_size - self.shard_begin - if self.chunk_temp.device.type == 'cpu': + if self.chunk_temp.device.type == "cpu": self.cuda_global_chunk = self.chunk_temp.to(get_current_device()) self.__update_tensors_ptr() else: @@ -298,12 +304,12 @@ def close_chunk(self): if self.keep_gathered: return - if self.pin_memory or self.shard_device.type == 'cpu': + if self.pin_memory or self.shard_device.type == "cpu": self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory) self.cpu_shard.copy_(self.cuda_shard) - self.cpu_vis_flag = True # cpu_shard has been visited + self.cpu_vis_flag = True # cpu_shard has been visited - if self.shard_device.type == 'cpu': + if self.shard_device.type == "cpu": self.cuda_shard = None def shard_move(self, device: torch.device, force_copy: bool = False): @@ -318,12 +324,12 @@ def shard_move(self, device: torch.device, force_copy: bool = False): # when the current chunk is not synchronized with the optimizer # just use another way for the movement if not self.optim_sync_flag: - assert device.type == 'cuda', "each chunk should first be moved to CUDA" + assert device.type == "cuda", "each chunk should first be moved to CUDA" self.__paired_shard_move() self.optim_sync_flag = True return - if device.type == 'cuda': + if device.type == "cuda": assert device == get_current_device(), "can't move chunk to another device" if self.cuda_shard: @@ -333,7 +339,7 @@ def shard_move(self, device: torch.device, force_copy: bool = False): if not self.pin_memory: self.cpu_shard = None - elif device.type == 'cpu': + elif device.type == "cpu": if self.cuda_shard is None: return @@ -350,8 +356,7 @@ def shard_move(self, device: torch.device, force_copy: bool = False): raise NotImplementedError def access_chunk(self): - """Make the chunk usable for the parameters inside it. It's an operation done in CUDA. - """ + """Make the chunk usable for the parameters inside it. It's an operation done in CUDA.""" # sanity check assert self.chunk_temp is None @@ -360,8 +365,7 @@ def access_chunk(self): self.__update_tensors_ptr() def release_chunk(self): - """Release the usable chunk. It's an operation done in CUDA. - """ + """Release the usable chunk. It's an operation done in CUDA.""" # sanity check assert self.chunk_temp is None @@ -369,8 +373,7 @@ def release_chunk(self): self.__scatter() def reduce(self): - """Reduce scatter all the gradients. It's an operation done in CUDA. - """ + """Reduce scatter all the gradients. It's an operation done in CUDA.""" # sanity check assert self.is_gathered @@ -423,20 +426,18 @@ def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Ten assert self.is_gathered tensor_info = self.tensors_info[tensor] - self.cuda_global_chunk[tensor_info.offset:tensor_info.end].copy_(data_slice.data.flatten()) - tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten()) + tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) def get_valid_length(self) -> int: - """Get the valid length of the chunk's payload. - """ + """Get the valid length of the chunk's payload.""" if self.keep_gathered: return self.utilized_size else: return self.valid_end - def init_pair(self, friend_chunk: 'Chunk') -> None: - """Initialize the paired chunk. - """ + def init_pair(self, friend_chunk: "Chunk") -> None: + """Initialize the paired chunk.""" if self.paired_chunk is None and friend_chunk.paired_chunk is None: self.paired_chunk = friend_chunk friend_chunk.paired_chunk = self @@ -445,8 +446,7 @@ def init_pair(self, friend_chunk: 'Chunk') -> None: assert friend_chunk.paired_chunk is self def optim_update(self) -> None: - """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer. - """ + """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.""" # sanity check assert self.paired_chunk is not None @@ -455,15 +455,15 @@ def optim_update(self) -> None: assert friend_chunk.is_gathered is True self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk) self.optim_sync_flag = True - elif friend_chunk.device_type == 'cuda' and self.device_type == 'cuda': + elif friend_chunk.device_type == "cuda" and self.device_type == "cuda": self.cuda_shard.copy_(friend_chunk.cuda_shard) self.optim_sync_flag = True self.cpu_vis_flag = False else: # optim_sync_flag is set to False # see shard_move function for more details - assert friend_chunk.device_type == 'cpu' - assert self.device_type == 'cpu' + assert friend_chunk.device_type == "cpu" + assert self.device_type == "cpu" self.optim_sync_flag = False self.cpu_vis_flag = False @@ -492,7 +492,7 @@ def __scatter(self): self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device) - self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin:self.shard_end]) + self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin : self.shard_end]) free_storage(self.cuda_global_chunk) self.is_gathered = False @@ -518,7 +518,7 @@ def __update_tensors_ptr(self) -> None: assert type(self.cuda_global_chunk) == torch.Tensor for tensor, tensor_info in self.tensors_info.items(): - tensor.data = self.cuda_global_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape) + tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState): self.tensor_state_cnter[tensor_info.state] -= 1 @@ -539,38 +539,41 @@ def __eq__(self, __o: object) -> bool: def __repr__(self, detailed: bool = True): output = [ "Chunk Information:\n", - "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(self.chunk_size, self.dtype, - self.pg_size), + "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format( + self.chunk_size, self.dtype, self.pg_size + ), "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format( - self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size) + self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size + ), ] - def print_tensor(tensor, prefix=''): - output.append("{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, - tensor.device)) + def print_tensor(tensor, prefix=""): + output.append( + "{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, tensor.device) + ) if self.chunk_temp is not None: output.append("\tchunk temp:\n") - print_tensor(tensor=self.chunk_temp, prefix='\t\t') + print_tensor(tensor=self.chunk_temp, prefix="\t\t") if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0: output.append("\tchunk total:\n") - print_tensor(tensor=self.cuda_global_chunk, prefix='\t\t') + print_tensor(tensor=self.cuda_global_chunk, prefix="\t\t") if self.cuda_shard is not None: output.append("\tcuda shard:\n") - print_tensor(tensor=self.cuda_shard, prefix='\t\t') + print_tensor(tensor=self.cuda_shard, prefix="\t\t") if self.cpu_shard is not None: output.append("\tcpu shard:\n") - print_tensor(tensor=self.cpu_shard, prefix='\t\t') + print_tensor(tensor=self.cpu_shard, prefix="\t\t") memory_info = self.memory_usage - output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info['cuda'], memory_info['cpu'])) + output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info["cuda"], memory_info["cpu"])) if detailed: output.append("\ttensor state monitor:\n") for st in TensorState: output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st])) - return ''.join(output) + return "".join(output) diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 1e96234326a9..957e41b02d49 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -20,27 +20,28 @@ class ChunkManager: """ def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None: - self.device = init_device or get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() self.kwargs_config = chunk_configuration for k, v in self.kwargs_config.items(): - self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size') - v['init_device'] = self.device + self.dp_degree_chunk_size_dict[k] = v.pop("chunk_size") + v["init_device"] = self.device self.chunk_groups: Dict[str, Deque[Chunk]] = dict() self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict() self.accessed_chunks: Set[Chunk] = set() self.accessed_mem: int = 0 - self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} - - def register_tensor(self, - tensor: torch.Tensor, - group_type: str, - config_key: int, - process_group: ProcessGroup, - cpu_offload: bool = False, - pin_memory: bool = False) -> None: + self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0} + + def register_tensor( + self, + tensor: torch.Tensor, + group_type: str, + config_key: int, + process_group: ProcessGroup, + cpu_offload: bool = False, + pin_memory: bool = False, + ) -> None: """ Register a tensor to the chunk manager. Then, the tensor should be accessed by `get_chunks`. @@ -94,25 +95,22 @@ def register_tensor(self, self.tensor_chunk_map[tensor] = chunk_group[-1] def close_all_groups(self): - """Close all the chunks of all groups. - """ + """Close all the chunks of all groups.""" for group_name in self.chunk_groups: self.__close_one_chunk(self.chunk_groups[group_name][-1]) def access_chunk(self, chunk: Chunk) -> None: - """Make the chunk can be used for calculation. - """ + """Make the chunk can be used for calculation.""" if chunk in self.accessed_chunks: return self.__sub_memory_usage(chunk.memory_usage) - if chunk.device_type == 'cpu': + if chunk.device_type == "cpu": chunk.shard_move(get_current_device()) self.__add_accessed_chunk(chunk) self.__add_memory_usage(chunk.memory_usage) def release_chunk(self, chunk: Chunk) -> None: - """Scatter the chunk in CUDA. - """ + """Scatter the chunk in CUDA.""" if chunk not in self.accessed_chunks: return if chunk.can_release: @@ -121,8 +119,7 @@ def release_chunk(self, chunk: Chunk) -> None: self.__add_memory_usage(chunk.memory_usage) def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = False) -> None: - """Move the shard of the chunk to the target device. - """ + """Move the shard of the chunk to the target device.""" if not chunk.can_move or chunk.device_type == device.type: return self.__sub_memory_usage(chunk.memory_usage) @@ -130,14 +127,12 @@ def move_chunk(self, chunk: Chunk, device: torch.device, force_copy: bool = Fals self.__add_memory_usage(chunk.memory_usage) def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: - """Transit tensor state according to pre-defined state machine. - """ + """Transit tensor state according to pre-defined state machine.""" chunk = self.tensor_chunk_map[tensor] chunk.tensor_trans_state(tensor, state) def reduce_chunk(self, chunk: Chunk) -> bool: - """Reduce or all reduce the chunk. - """ + """Reduce or all reduce the chunk.""" if not chunk.can_reduce: return False self.__sub_memory_usage(chunk.memory_usage) @@ -213,18 +208,17 @@ def add_extern_static_tensor(self, tensor: torch.Tensor) -> None: def __repr__(self) -> str: msg = [ - 'Chunk Manager Information:\n', - 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' + "Chunk Manager Information:\n", + "Total memory: " + ", ".join([f"{k}={v}B" for k, v in self.total_mem.items()]) + "\n", ] for group_name, group in self.chunk_groups.items(): - msg.append(f'Group {group_name}:\n') + msg.append(f"Group {group_name}:\n") for i, chunk in enumerate(group): - msg.append(f'[{i}] {chunk}\n') - return ''.join(msg) + msg.append(f"[{i}] {chunk}\n") + return "".join(msg) def __get_chunk_group(self, group_name: str) -> Deque[Chunk]: - """Register a chunk group. - """ + """Register a chunk group.""" if group_name not in self.chunk_groups: self.chunk_groups[group_name] = deque() return self.chunk_groups[group_name] diff --git a/colossalai/zero/gemini/chunk/search_utils.py b/colossalai/zero/gemini/chunk/search_utils.py index abaca5f8294d..24d8537bad90 100644 --- a/colossalai/zero/gemini/chunk/search_utils.py +++ b/colossalai/zero/gemini/chunk/search_utils.py @@ -76,8 +76,9 @@ def _tensor_numel(local_param: ColoParameter) -> int: return local_param.numel() -def classify_params_by_dp_degree(param_order: OrderedParamGenerator, - process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]: +def classify_params_by_dp_degree( + param_order: OrderedParamGenerator, process_group: ProcessGroup +) -> Dict[int, List[ColoParameter]]: """classify_params_by_dp_degree Classify the parameters by their dp degree @@ -105,14 +106,15 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator, def search_chunk_configuration( - model: nn.Module, - search_range_m: float, - search_interval: int, # hidden size is the best value for the interval - min_chunk_size_m: float = 32, - filter_exlarge_params: bool = True, - strict_ddp_flag: bool = False, - process_group: Optional[ProcessGroup] = None, - memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]: + model: nn.Module, + search_range_m: float, + search_interval: int, # hidden size is the best value for the interval + min_chunk_size_m: float = 32, + filter_exlarge_params: bool = True, + strict_ddp_flag: bool = False, + process_group: Optional[ProcessGroup] = None, + memstas: Optional[MemStats] = None, +) -> Tuple[Dict, int, int]: """search_chunk_configuration Search the chunk configuration for a model. @@ -168,7 +170,7 @@ def search_chunk_configuration( max_size = max(max_size, max(size_dict[key])) start_size = int(math.ceil(max_size / search_interval) * search_interval) - min_chunk_waste = float('+inf') + min_chunk_waste = float("+inf") best_chunk_size = start_size for chunk_size in range(start_size, start_size + search_range + 1, search_interval): diff --git a/colossalai/zero/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py index e98e9cf9c314..7a2ea360650b 100644 --- a/colossalai/zero/gemini/chunk/utils.py +++ b/colossalai/zero/gemini/chunk/utils.py @@ -5,8 +5,6 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.utils import is_ddp_ignored - from .manager import ChunkManager from .search_utils import search_chunk_configuration @@ -17,15 +15,17 @@ def safe_div(a, b): return a / b -def init_chunk_manager(model: nn.Module, - init_device: Optional[torch.device] = None, - hidden_dim: Optional[int] = None, - verbose: bool = False, - **kwargs) -> ChunkManager: +def init_chunk_manager( + model: nn.Module, + init_device: Optional[torch.device] = None, + hidden_dim: Optional[int] = None, + verbose: bool = False, + **kwargs, +) -> ChunkManager: if hidden_dim: search_interval = hidden_dim else: - search_interval = 1024 # defaults to 1024 + search_interval = 1024 # defaults to 1024 kwargs["search_interval"] = search_interval dist.barrier() @@ -41,11 +41,13 @@ def init_chunk_manager(model: nn.Module, wasted_size /= mega_unit if verbose and dist.get_rank() == 0: - print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s), - "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size), - "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)), - sep='', - flush=True) + print( + "searching chunk configuration is completed in {:.2f} s.\n".format(span_s), + "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size), + "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)), + sep="", + flush=True, + ) dist.barrier() chunk_manager = ChunkManager(config_dict, init_device) diff --git a/colossalai/zero/gemini/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py index 549635af4332..ab2ff8f920aa 100644 --- a/colossalai/zero/gemini/colo_init_context.py +++ b/colossalai/zero/gemini/colo_init_context.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import Any, Iterator, Optional, Tuple, Union import torch from torch import nn @@ -12,7 +12,7 @@ def _named_params_with_replica( module: nn.Module, - prefix: str = '', + prefix: str = "", recurse: bool = True, ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]: modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] @@ -21,16 +21,17 @@ def _named_params_with_replica( for name, val in mod._parameters.items(): if val is None: continue - name = mod_prefix + ('.' if mod_prefix else '') + name + name = mod_prefix + ("." if mod_prefix else "") + name yield name, val -def _convert_to_coloparam(param: torch.nn.Parameter, - device: torch.device, - dtype=torch.float, - default_pg: Optional[ProcessGroup] = None, - default_dist_spec: Optional[Any] = None) -> ColoParameter: - +def _convert_to_coloparam( + param: torch.nn.Parameter, + device: torch.device, + dtype=torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec: Optional[Any] = None, +) -> ColoParameter: if type(param) is ColoParameter: return param # detaching tensor is necessary for optimizers. @@ -66,12 +67,13 @@ def ColoModulize(module): class ColoInitContext(InsertPostInitMethodToModuleSubClasses): - - def __init__(self, - device: torch.device = torch.device('cpu'), - dtype: torch.dtype = torch.float, - default_pg: Optional[ProcessGroup] = None, - default_dist_spec=None): + def __init__( + self, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None, + ): """ Args: device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu'). @@ -89,6 +91,7 @@ def __init__(self, def _register_colo_modules(self): from colossalai.legacy.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module + register_colo_module(torch.nn.Linear, ColoLinear()) register_colo_module(torch.nn.Embedding, ColoEmbedding()) @@ -105,25 +108,25 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): if type(param) is ColoParameter: continue - split = name.rfind('.') - if split >= 0: # param in submodule + split = name.rfind(".") + if split >= 0: # param in submodule module_name = name[:split] - param_name = name[split + 1:] + param_name = name[split + 1 :] else: - module_name = '' # param in current module + module_name = "" # param in current module param_name = name name_list.append((module_name, param_name)) - replaced_tensors = dict( - ) # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference + replaced_tensors = dict() # record mapping between (torch.Tensor, ColoTensor) to distinguish the same reference for module_name, param_name in name_list: submodule = module.get_submodule(module_name) param = submodule.get_parameter(param_name) if param in replaced_tensors: colo_param = replaced_tensors[param] else: - colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg, - self._default_dist_spec) + colo_param = _convert_to_coloparam( + param, self._device, self._dtype, self._default_pg, self._default_dist_spec + ) replaced_tensors[param] = colo_param delattr(submodule, param_name) setattr(submodule, param_name, colo_param) @@ -136,11 +139,11 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): for param in module.parameters(): param_number += 1 - meta_param_number += (param.device.type == 'meta') + meta_param_number += param.device.type == "meta" for buffer in module.buffers(): buffer_number += 1 - meta_buffer_number += (buffer.device.type == 'meta') + meta_buffer_number += buffer.device.type == "meta" if meta_param_number > 0 and meta_param_number != param_number: raise ValueError("Meta parameters and valued parameters can not be in the same model") @@ -152,11 +155,13 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): buffer.data = buffer.data.to(device=self._device) -def post_process_colo_init_ctx(model: torch.nn.Module, - device: torch.device = torch.device('cpu'), - dtype: torch.dtype = torch.float, - default_pg: Optional[ProcessGroup] = None, - default_dist_spec=None): +def post_process_colo_init_ctx( + model: torch.nn.Module, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float, + default_pg: Optional[ProcessGroup] = None, + default_dist_spec=None, +): """post_process_colo_init_ctx This function is called after `ColoInitContext`. @@ -178,8 +183,8 @@ def post_process_colo_init_ctx(model: torch.nn.Module, # print(f"{n} is not a ColoParameter. We are going to converting it to ColoParameter") torch_params.append((n, p)) - for (n, param) in torch_params: - name_list = n.split('.') + for n, param in torch_params: + name_list = n.split(".") module = model for i in range(len(name_list) - 1): module = module._modules[name_list[i]] diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 918b08cd3150..580b497ce719 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -10,7 +10,7 @@ from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group -from colossalai.checkpoint_io.utils import StateDictSharder, calculate_tensor_size +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger @@ -27,10 +27,10 @@ try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" __all__ = [ - 'GeminiDDP', + "GeminiDDP", ] @@ -54,27 +54,28 @@ class GeminiDDP(ModelWrapper): """ def __init__( - self, - module: torch.nn.Module, - chunk_config_dict: Optional[dict] = None, - chunk_init_device: torch.device = torch.device('cpu'), - placement_policy: str = "static", - shard_param_frac: float = 1.0, # only for static placement - offload_optim_frac: float = 0.0, # only for static placement - offload_param_frac: float = 0.0, # only for static placement - warmup_non_model_data_ratio: float = 0.8, # only for auto placement - steady_cuda_cap_ratio: float = 0.9, # only for auto placement - search_range_m: int = 32, # chunk search options - hidden_dim: Optional[int] = None, # chunk search options - min_chunk_size_m: float = 32, # chunk search options - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False, - scatter_after_inference: bool = True, - mixed_precision: torch.dtype = torch.float16, - process_group: Optional[ProcessGroup] = None, - memstats: Optional[MemStats] = None, # genimi memory stats - verbose: bool = False) -> None: + self, + module: torch.nn.Module, + chunk_config_dict: Optional[dict] = None, + chunk_init_device: torch.device = torch.device("cpu"), + placement_policy: str = "static", + shard_param_frac: float = 1.0, # only for static placement + offload_optim_frac: float = 0.0, # only for static placement + offload_param_frac: float = 0.0, # only for static placement + warmup_non_model_data_ratio: float = 0.8, # only for auto placement + steady_cuda_cap_ratio: float = 0.9, # only for auto placement + search_range_m: int = 32, # chunk search options + hidden_dim: Optional[int] = None, # chunk search options + min_chunk_size_m: float = 32, # chunk search options + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False, + scatter_after_inference: bool = True, + mixed_precision: torch.dtype = torch.float16, + process_group: Optional[ProcessGroup] = None, + memstats: Optional[MemStats] = None, # genimi memory stats + verbose: bool = False, + ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) if chunk_config_dict is not None: self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device) @@ -82,22 +83,26 @@ def __init__( # some ugly hotfix for the compatibility with Lightning if search_range_m is None: search_range_m = 32 - self.chunk_manager = init_chunk_manager(model=module, - init_device=chunk_init_device, - hidden_dim=hidden_dim, - search_range_m=search_range_m, - min_chunk_size_m=min_chunk_size_m, - strict_ddp_flag=strict_ddp_mode, - process_group=process_group, - verbose=verbose) - self.gemini_manager = GeminiManager(placement_policy, - self.chunk_manager, - memstats, - shard_param_frac=shard_param_frac, - offload_optim_frac=offload_optim_frac, - offload_param_frac=offload_param_frac, - warmup_non_model_data_ratio=warmup_non_model_data_ratio, - steady_cuda_cap_ratio=steady_cuda_cap_ratio) + self.chunk_manager = init_chunk_manager( + model=module, + init_device=chunk_init_device, + hidden_dim=hidden_dim, + search_range_m=search_range_m, + min_chunk_size_m=min_chunk_size_m, + strict_ddp_flag=strict_ddp_mode, + process_group=process_group, + verbose=verbose, + ) + self.gemini_manager = GeminiManager( + placement_policy, + self.chunk_manager, + memstats, + shard_param_frac=shard_param_frac, + offload_optim_frac=offload_optim_frac, + offload_param_frac=offload_param_frac, + warmup_non_model_data_ratio=warmup_non_model_data_ratio, + steady_cuda_cap_ratio=steady_cuda_cap_ratio, + ) self.force_outputs_fp32 = force_outputs_fp32 self.param_op_hook = GeminiZeROHook(self.gemini_manager) self.fp32_params: List[torch.Tensor] = list() @@ -126,13 +131,15 @@ def __init__( self.param2name[param] = name for m_name, m_var in module.named_modules(): for p_name, p_var in m_var.named_parameters(recurse=False): - param_name = m_name + '.' + p_name if m_name else p_name + param_name = m_name + "." + p_name if m_name else p_name self.name2param[param_name] = p_var - self._init_chunks(param_order=param_order, - strict_ddp_mode=strict_ddp_mode, - cpu_offload=self.gemini_manager.policy_name != 'cuda', - pin_memory=pin_memory) + self._init_chunks( + param_order=param_order, + strict_ddp_mode=strict_ddp_mode, + cpu_offload=self.gemini_manager.policy_name != "cuda", + pin_memory=pin_memory, + ) super().__init__(module) self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module) self._cast_buffers() @@ -146,19 +153,18 @@ def __init__( def parameters(self, recurse: bool = True): return self.module.parameters(recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True): + def named_parameters(self, prefix: str = "", recurse: bool = True): return self.module.named_parameters(prefix, recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True): + def named_buffers(self, prefix: str = "", recurse: bool = True): return self.module.named_buffers(prefix, recurse) def named_children(self): return self.module.named_children() - def named_modules(self, - memo: Optional[Set[torch.nn.Module]] = None, - prefix: str = '', - remove_duplicate: bool = True): + def named_modules( + self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): return self.module.named_modules(memo, prefix, remove_duplicate) @staticmethod @@ -184,11 +190,9 @@ def unwrap(self): # as save/load state dict is overwrited, only return self return self - def _get_non_persistent_buffers_set(self, - module, - memo: Optional[Set[nn.Module]] = None, - prefix: str = '', - remove_duplicate: bool = True): + def _get_non_persistent_buffers_set( + self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True + ): r""" Args: memo: a memo to store the set of modules already added to the result @@ -204,19 +208,20 @@ def _get_non_persistent_buffers_set(self, if remove_duplicate: memo.add(module) self_non_persistent_set = set( - map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set)) + map(lambda key: prefix + ("." if prefix else "") + key, module._non_persistent_buffers_set) + ) for name, sub_module in module._modules.items(): if sub_module is None: continue - submodule_prefix = prefix + ('.' if prefix else '') + name - child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, - remove_duplicate) + submodule_prefix = prefix + ("." if prefix else "") + name + child_non_persistent_set = self._get_non_persistent_buffers_set( + sub_module, memo, submodule_prefix, remove_duplicate + ) self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set) return self_non_persistent_set def _post_forward(self): - """This function is only triggered for inference. - """ + """This function is only triggered for inference.""" access_list = list(self.chunk_manager.accessed_chunks) # we need to scatter all accessed chunks and move them to their original places for chunk in access_list: @@ -233,7 +238,8 @@ def forward(self, *args, **kwargs): # check whether we are in a inference mode grad_flag = torch.is_grad_enabled() if not grad_flag: - assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup( + assert ( + not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup() ), "You should run a completed iteration as your warmup iter" args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision) @@ -250,8 +256,7 @@ def forward(self, *args, **kwargs): return outputs def _inference_forward(self, *args, **kwargs): - """This function is only triggered for inference. - """ + """This function is only triggered for inference.""" fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook) if not self.scatter_after_inference: # gather all chunks @@ -287,12 +292,14 @@ def _post_backward(self): if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"): error_params.append(self.param2name[param]) error_str = "\n\t".join(error_params) - raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", - "The most possible reason is that the model is not compatible with GeminiDDP.\n", - f"{error_str}") + raise RuntimeError( + "ZERO DDP error: the synchronization of gradients doesn't exit properly.", + "The most possible reason is that the model is not compatible with GeminiDDP.\n", + f"{error_str}", + ) self._setup_grads_ptr() self._logger.debug( - f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' + f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}" ) self.gemini_manager.post_iter() @@ -314,8 +321,10 @@ def grad_handle(self, p, grad): with torch._C.DisableTorchFunction(): chunk = self.chunk_manager.get_chunk(p) if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: - raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " - "Some unsupported torch function is operated upon this parameter.") + raise RuntimeError( + f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " + "Some unsupported torch function is operated upon this parameter." + ) self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) chunk.copy_tensor_to_chunk_slice(p, grad) reduced = self.chunk_manager.reduce_chunk(chunk) @@ -339,12 +348,9 @@ def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: for tensor in chunk.get_tensors(): self.grads_device[tensor] = device - def state_dict(self, - destination=None, - prefix='', - keep_vars=False, - only_rank_0: bool = True, - dtype: torch.dtype = torch.float16): + def state_dict( + self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16 + ): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. @@ -391,7 +397,7 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch. record_tensor = torch.empty([0]) record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) if record_flag: - record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() + record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cpu() assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor @@ -399,8 +405,9 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch. del temp_chunk return chunk_to_save_data - def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool, - dtype: torch.dtype) -> Dict: + def _get_param_to_save_data( + self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype + ) -> Dict: """ get param content from chunks. @@ -459,11 +466,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, destination[prefix + name] = buf if keep_vars else buf.detach() # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + if ( + getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): destination[extra_state_key] = self.get_extra_state() - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): r"""Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned @@ -491,32 +500,38 @@ def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: error_msgs: List[str] = [] # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) + metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: # mypy isn't aware that "_metadata" exists in state_dict - state_dict._metadata = metadata # type: ignore[attr-defined] + state_dict._metadata = metadata # type: ignore[attr-defined] - prefix = '' + prefix = "" local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) if strict: if len(unexpected_keys) > 0: error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join( - '"{}"'.format(k) for k in unexpected_keys))) + 0, + "Unexpected key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in unexpected_keys) + ), + ) if len(missing_keys) > 0: error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + 0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)) + ) if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs)) + ) return _IncompatibleKeys(missing_keys, unexpected_keys) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): r"""Copies parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this @@ -564,19 +579,21 @@ def load(param_name, dest_tensor, copy_func): input_param = input_param[0] if input_param.shape != dest_tensor.shape: # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format(state_key, input_param.shape, - dest_tensor.shape)) + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format(state_key, input_param.shape, dest_tensor.shape) + ) return try: with torch.no_grad(): copy_func(input_param) except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), - input_param.size(), ex.args)) + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format(state_key, dest_tensor.size(), input_param.size(), ex.args) + ) elif strict: missing_keys.append(state_key) @@ -600,15 +617,15 @@ def load_fp32_parameter(chunk_slice, data): for tensor, tensor_info in chunk.tensors_info.items(): parameter_name = fp32_to_name[tensor] - parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] + parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) if chunk.is_gathered: chunk.cuda_global_chunk.copy_(temp_chunk) elif chunk.cuda_shard is not None: - chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end]) else: - chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end]) del temp_chunk @@ -622,8 +639,10 @@ def load_fp32_parameter(chunk_slice, data): load(name, buf, buf.copy_) extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "set_extra_state", - torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state: + if ( + getattr(self.__class__, "set_extra_state", torch.nn.Module.set_extra_state) + is not torch.nn.Module.set_extra_state + ): if extra_state_key in state_dict: self.set_extra_state(state_dict[extra_state_key]) elif strict: @@ -634,7 +653,7 @@ def load_fp32_parameter(chunk_slice, data): if strict: for key in state_dict.keys(): if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):] + input_name = key[len(prefix) :] if input_name not in local_state: unexpected_keys.append(key) @@ -659,18 +678,22 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi p.data = p.data.to(self.mixed_precision) # register the fp16 parameter and fp32 parameter in the chunk manager - self.chunk_manager.register_tensor(tensor=p, - group_type='fp16_param', - config_key=dp_world_size, - process_group=self.dp_process_group, - cpu_offload=cpu_offload, - pin_memory=pin_memory) - self.chunk_manager.register_tensor(tensor=fp32_p, - group_type='fp32_param', - config_key=dp_world_size, - process_group=self.dp_process_group, - cpu_offload=cpu_offload, - pin_memory=pin_memory) + self.chunk_manager.register_tensor( + tensor=p, + group_type="fp16_param", + config_key=dp_world_size, + process_group=self.dp_process_group, + cpu_offload=cpu_offload, + pin_memory=pin_memory, + ) + self.chunk_manager.register_tensor( + tensor=fp32_p, + group_type="fp32_param", + config_key=dp_world_size, + process_group=self.dp_process_group, + cpu_offload=cpu_offload, + pin_memory=pin_memory, + ) self.fp16_params.append(p) self.fp32_params.append(fp32_p) @@ -694,7 +717,7 @@ def _cast_buffers(self): if torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision) - def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None: + def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, "LazyTensor"]) -> None: """Convert parameter to ColoParameter in-place. Args: p (Union[nn.Parameter, ColoParameter, LazyTensor]): parameter to be converted @@ -709,12 +732,14 @@ def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) p.__class__ = ColoParameter p.__init__(p, requires_grad=requires_grad) - def state_dict_shard(self, - prefix: str = '', - keep_vars: bool = False, - max_shard_size: int = 1024, - only_rank_0: bool = True, - dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]: + def state_dict_shard( + self, + prefix: str = "", + keep_vars: bool = False, + max_shard_size: int = 1024, + only_rank_0: bool = True, + dtype: torch.dtype = torch.float16, + ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. Both parameters and persistent buffers (e.g. running averages) are included. @@ -770,8 +795,10 @@ def state_dict_shard(self, yield block, block_size # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + if ( + getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): extra_state = self.get_extra_state() block, block_size = sharder.append_param(extra_state_key, extra_state) if block is not None: diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index dbc2924858e6..480a14511b69 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -17,7 +17,6 @@ class TrainingPhase(Enum): class GeminiZeROHook(ColoParamOpHook): - def __init__(self, gemini_manager: GeminiManager) -> None: super().__init__() self._gemini_manager = gemini_manager @@ -40,7 +39,11 @@ def pre_op(self, params): def post_op(self, params): params = [p for p in params if not is_ddp_ignored(p)] for p in params: - tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD + tensor_state = ( + TensorState.HOLD + if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad + else TensorState.HOLD_AFTER_BWD + ) self._chunk_manager.trans_tensor_state(p, tensor_state) def pre_forward(self, params: List[torch.Tensor]) -> None: diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index b8e4717908f7..f7ff3f6cdd86 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -26,12 +26,13 @@ class GeminiManager: memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration. """ - def __init__(self, - placement_policy: str, - chunk_manager: ChunkManager, - memstats: Optional[MemStats] = None, - **placement_kwargs) -> None: - + def __init__( + self, + placement_policy: str, + chunk_manager: ChunkManager, + memstats: Optional[MemStats] = None, + **placement_kwargs, + ) -> None: assert placement_policy in PlacementPolicyFactory.get_policy_names() self.policy_name = placement_policy policy_cls = PlacementPolicyFactory.create(placement_policy) @@ -39,8 +40,9 @@ def __init__(self, self._premade_memstats_ = memstats is not None self._memstats = memstats - self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, - self._memstats) if policy_cls.need_mem_stats else None + self._mem_stats_collector = ( + ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None + ) self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 @@ -62,7 +64,7 @@ def reset_attributes(self): @property def need_warmup(self) -> bool: - return self.policy_name in ('auto', 'const') + return self.policy_name in ("auto", "const") def is_warmup(self): return self._warmup @@ -85,15 +87,14 @@ def pre_iter(self, *args): self._mem_stats_collector.start_collection() def post_iter(self): - """This function must be called when each iteration finishes - """ + """This function must be called when each iteration finishes""" if self._mem_stats_collector and self._warmup: self._mem_stats_collector.finish_collection() self._warmup = False self.reset_attributes() def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: - """ Adjust the layout of stateful tensors according to the information provided + """Adjust the layout of stateful tensors according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. """ # find stateful tensor in state COMPUTE @@ -102,11 +103,13 @@ def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None: cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks) self._layout_time += time() - start - vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list, - cuda_demand=cuda_demand, - warmup=self._warmup, - compute_list=self._compute_list, - compute_idx=self._compute_idx) + vol, evict_time = self._placement_policy.evict_tensors( + can_evict_chunks=hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx, + ) self._d2h_volume += vol self._evict_time += evict_time @@ -118,12 +121,12 @@ def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, start = time() cuda_demand = 0 for chunk in chunks: - if chunk.device_type == 'cuda': + if chunk.device_type == "cuda": if chunk.is_gathered: pass else: cuda_demand += chunk.chunk_mem - chunk.shard_mem - elif chunk.device_type == 'cpu': + elif chunk.device_type == "cpu": cuda_demand += chunk.chunk_mem else: raise RuntimeError @@ -159,6 +162,7 @@ def cuda_margin_mem(self) -> Optional[float]: def is_cuda_margin_mem_avail(self) -> bool: return self._placement_policy.need_mem_stats - def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, - torch.device]) -> None: + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: self._placement_policy.setup_grads_device(params, grads_device_map) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 0c593deff225..d785eda2dc12 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -10,34 +10,35 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin -from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam -from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.utils import disposable, get_current_device, is_ddp_ignored from .chunk import Chunk, ChunkManager from .gemini_ddp import GeminiDDP -__all__ = ['GeminiOptimizer', 'GeminiAdamOptimizer'] +__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"] _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): - - def __init__(self, - module: GeminiDDP, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32) -> None: - super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, - max_scale) + def __init__( + self, + module: GeminiDDP, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + ) self.module = module def check_local_overflow(self) -> bool: @@ -77,25 +78,28 @@ class GeminiOptimizer(OptimizerWrapper): verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False. """ - def __init__(self, - optim: Optimizer, - module: GeminiDDP, - gpu_margin_mem_ratio: float = 0.0, - initial_scale: float = 2**32, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0, - verbose: bool = False, - **defaults: Any): + def __init__( + self, + optim: Optimizer, + module: GeminiDDP, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + verbose: bool = False, + **defaults: Any, + ): super().__init__(optim) assert isinstance(module, GeminiDDP) - assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \ - f"{_AVAIL_OPTIM_LIST}" + assert type(optim) in _AVAIL_OPTIM_LIST, ( + "You should use an optimizer in the available list:\n" f"{_AVAIL_OPTIM_LIST}" + ) self.module = module self.gemini_manager = module.gemini_manager self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager @@ -118,8 +122,10 @@ def __init__(self, for name, param in module.named_parameters(): if is_ddp_ignored(param): if param.requires_grad: - warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! " - "You should handle its optimizer update by yourself!") + warnings.warn( + f"Parameter `{name}` is ignored by DDP but requires gradient! " + "You should handle its optimizer update by yourself!" + ) else: ddp_param_list.append(param) @@ -132,14 +138,16 @@ def __init__(self, self.__init__optimizer() if module.mixed_precision is torch.float16: - self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module, - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin( + module, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) elif module.mixed_precision is torch.bfloat16: self.mix_precision_mixin = BF16MixedPrecisionMixin() else: @@ -148,12 +156,15 @@ def __init__(self, self._logger = get_dist_logger() self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) - assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' + assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f"gpu_margin_mem_ratio must >=0.0 and <=1.0" # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors, # and it must set `num_fp32_shards_per_param` correctly - self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr( - optim, 'num_fp32_shards_per_param', 0) >= 2 + self._should_move_fp32_params_h2d: bool = ( + self.gemini_manager.is_cuda_margin_mem_avail + and self.gpu_margin_mem_ratio > 0.0 + and getattr(optim, "num_fp32_shards_per_param", 0) >= 2 + ) if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail: self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0]) @@ -161,7 +172,7 @@ def __init__(self, def _set_grad_ptr(self): for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: chunk32 = self.param_to_chunk32[fake_param] begin, end = self.param_to_range[fake_param] chunk16 = chunk32.paired_chunk @@ -173,7 +184,7 @@ def _set_grad_ptr(self): def _update_fp16_params(self): none_tensor = torch.empty([0]) for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: assert fake_param.grad is None fake_param.data = none_tensor.to(fake_param.device) @@ -198,7 +209,7 @@ def _calc_global_norm(self) -> float: group_to_norm[c16.torch_pg] = 0.0 group_to_norm[c16.torch_pg] += c16.l2_norm - c16.l2_norm = None # clear l2 norm + c16.l2_norm = None # clear l2 norm comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) for group, part_norm in group_to_norm.items(): @@ -230,9 +241,9 @@ def step(self, *args, **kwargs): if self.mix_precision_mixin.should_skip_step(): if self.verbose: - self._logger.info(f'Found overflow. Skip step') - self._clear_global_norm() # clear recorded norm - self.zero_grad() # reset all gradients + self._logger.info(f"Found overflow. Skip step") + self._clear_global_norm() # clear recorded norm + self.zero_grad() # reset all gradients self._update_fp16_params() return @@ -269,11 +280,11 @@ def _maybe_move_fp32_params(self): fp32_params_used_cuda_margin_mem = 0 for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: chunk32 = self.param_to_chunk32[fake_param] chunk16 = chunk32.paired_chunk - if chunk32.device_type == 'cuda': + if chunk32.device_type == "cuda": continue if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: @@ -284,9 +295,9 @@ def _maybe_move_fp32_params(self): fp32_params_used_cuda_margin_mem += chunk32.payload_mem for group in self.param_groups: - for fake_param in group['params']: + for fake_param in group["params"]: chunk32 = self.param_to_chunk32[fake_param] - if chunk32.device_type == 'cuda': + if chunk32.device_type == "cuda": state = self.optim.state[fake_param] for k, v in state.items(): if isinstance(v, torch.Tensor): @@ -294,14 +305,13 @@ def _maybe_move_fp32_params(self): def _register_states_(self): for group in self.optim.param_groups: - for p in group['params']: + for p in group["params"]: state = self.optim.state[p] for val in state.values(): if isinstance(val, torch.Tensor): self.chunk_manager.add_extern_static_tensor(val) def __init__optimizer(self): - def get_range_pair(local_chunk: Chunk, local_param: Parameter): param_info = local_chunk.tensors_info[local_param] if local_chunk.keep_gathered: @@ -313,10 +323,9 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter): param_id = -1 for group in self.optim.param_groups: fake_params_list = list() - group_backup = {k: v for k, v in group.items() if k != 'params'} + group_backup = {k: v for k, v in group.items() if k != "params"} group_ids = [] - for param in group['params']: - + for param in group["params"]: # Record the mapping of id to current param. param_id += 1 self.id_to_real_params[param_id] = param @@ -337,12 +346,12 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter): fake_params_list.append(fake_param) # Update self.optim.param_groups as well as backup group. - group['params'] = fake_params_list - group_backup['params'] = group_ids + group["params"] = fake_params_list + group_backup["params"] = group_ids self.param_groups_backup.append(group_backup) def get_offsets(self, param_id: int) -> tuple: - ''' + """ Args: param_id(int): The id of parameter. @@ -351,7 +360,7 @@ def get_offsets(self, param_id: int) -> tuple: shard_offset(int): Offset of its optimizer state shard relative to the whole optimizer state. shard_size(int): Length of parameter shard owned by current process. - ''' + """ if param_id not in self.id_to_fake_params: return -1, -1, -1 @@ -425,11 +434,11 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: if is_collector: states = self.optim.state[fake_param] for state_name in state_names: - if state_name == 'step': + if state_name == "step": # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32. - collected_states[state_name] = torch.tensor(states['step'], - dtype=torch.float32, - requires_grad=False).cpu() + collected_states[state_name] = torch.tensor( + states["step"], dtype=torch.float32, requires_grad=False + ).cpu() else: state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() collected_states[state_name] = torch.reshape(state_tensor, param.shape) @@ -441,12 +450,13 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: # Collector gets prepared for state collecting. if is_collector: for state_name in state_names: - if state_name == 'step': + if state_name == "step": # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32. collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu() else: - collected_states[state_name] = torch.zeros(param.numel(), dtype=torch.float32, - requires_grad=False).cpu() + collected_states[state_name] = torch.zeros( + param.numel(), dtype=torch.float32, requires_grad=False + ).cpu() # Materials for gathering, including compacted state tensors, and the offset of shard inside each state. compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None @@ -465,8 +475,9 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: shard_size = state_shard[2] if compacted_states is None: continue - self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset, - shard_size) + self.load_from_compacted_states( + compacted_states, collected_states, state_names, shard_offset, shard_size + ) # Reshape tensors if is_collector: @@ -476,14 +487,16 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: return collected_states - def pack_optimizer_states_to_tensor(self, - param_id: int, - state_names: list, - device: torch.device = torch.device('cuda'), - dtype: torch.dtype = torch.float32) -> torch.Tensor: - ''' + def pack_optimizer_states_to_tensor( + self, + param_id: int, + state_names: list, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ With param id given, pack its optimizer states into a compact tensor and return. - ''' + """ if param_id not in self.id_to_fake_params: return None @@ -493,7 +506,7 @@ def pack_optimizer_states_to_tensor(self, shard_size = param_range[1] - param_range[0] compacted_size = 0 for name in state_names: - if name == 'step': + if name == "step": compacted_size += 1 else: compacted_size += shard_size @@ -502,7 +515,7 @@ def pack_optimizer_states_to_tensor(self, next_state_offset = 0 for state_name, state_tensor in states.items(): # State 'step' needs special operation. - if state_name == 'step': + if state_name == "step": if isinstance(state_tensor, torch.Tensor): compacted_states[next_state_offset] = state_tensor[0].item() else: @@ -511,47 +524,53 @@ def pack_optimizer_states_to_tensor(self, next_state_offset += 1 else: assert state_tensor.numel() == shard_size - compacted_states[next_state_offset:next_state_offset + shard_size].copy_(state_tensor) + compacted_states[next_state_offset : next_state_offset + shard_size].copy_(state_tensor) next_state_offset += shard_size return compacted_states - def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, state_names: list, - shard_start: int, shard_size: int): - ''' + def load_from_compacted_states( + self, + compacted_states: torch.Tensor, + collected_states: dict, + state_names: list, + shard_start: int, + shard_size: int, + ): + """ Given a tensor carrying compacted optimizer states, update these states to collected_states. - ''' + """ shard_end = shard_start + shard_size next_state_offset = 0 for state_name in state_names: - if state_name == 'step': - collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(), - dtype=torch.float32, - requires_grad=False).cpu() + if state_name == "step": + collected_states["step"].data = torch.tensor( + compacted_states[next_state_offset].item(), dtype=torch.float32, requires_grad=False + ).cpu() next_state_offset += 1 else: target_segment = collected_states[state_name][shard_start:shard_end] - target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size]) + target_segment.copy_(compacted_states[next_state_offset : next_state_offset + shard_size]) next_state_offset += shard_size def get_param_groups_for_saving(self) -> list: - ''' + """ Return the param_groups in Pytorch format when saving to checkpoint. - ''' + """ param_groups = copy.deepcopy(self.param_groups_backup) # To be compatible with pytorch checkpointing, # store extra hyperparameters used by pytorch Adam optimizer. torch_special_hyperparameters = { - 'amsgrad': False, - 'maximize': False, - 'foreach': None, - 'capturable': False, - 'differentiable': False, - 'fused': False + "amsgrad": False, + "maximize": False, + "foreach": None, + "capturable": False, + "differentiable": False, + "fused": False, } for group in param_groups: @@ -580,13 +599,13 @@ def state_dict(self, only_rank_0: bool = True) -> dict: so it should be called only when memory resources are abundant. """ state_dict = {} - state_dict['param_groups'] = self.get_param_groups_for_saving() + state_dict["param_groups"] = self.get_param_groups_for_saving() # Collect optimizer states. - state_dict['state'] = dict() + state_dict["state"] = dict() for param_id in self.id_to_real_params.keys(): dist.barrier() - state_dict['state'][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) + state_dict["state"][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) return state_dict def load_param_groups(self, saved_param_groups: list): @@ -601,13 +620,13 @@ def load_param_groups(self, saved_param_groups: list): for group in saved_param_groups: fake_params_list = list() - updated_group = {k: v for k, v in group.items() if k != 'params'} - for param_id in group['params']: + updated_group = {k: v for k, v in group.items() if k != "params"} + for param_id in group["params"]: if param_id not in self.id_to_fake_params: continue fake_param = self.id_to_fake_params[param_id] fake_params_list.append(fake_param) - updated_group['params'] = fake_params_list + updated_group["params"] = fake_params_list self.optim.param_groups.append(updated_group) def load_single_param_states(self, param_id: int, saved_states: dict): @@ -621,15 +640,14 @@ def cast(param, state_range, value, key=None): """ assert isinstance(value, torch.Tensor) ret_val = value - if (key == "step"): + if key == "step": assert value.numel() == 1 ret_val = int(value.item()) else: state_start, state_end = state_range - ret_val = torch.zeros(state_end - state_start, - dtype=torch.float32, - device=param.device, - requires_grad=False) + ret_val = torch.zeros( + state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False + ) ret_val.copy_(value.flatten()[state_start:state_end]) return ret_val @@ -642,7 +660,7 @@ def cast(param, state_range, value, key=None): updated_states = dict() for k, v in saved_states.items(): updated_states[k] = cast(fake_param, state_range, v, k) - del v # clean loaded states + del v # clean loaded states self.optim.state[fake_param].update(updated_states) def load_param_states(self, param_states: dict): @@ -658,8 +676,8 @@ def load_param_states(self, param_states: dict): def optimizer_loading_epilogue(self): # Epilogue when loading state_dict to pytorch optimizer. - self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle. - self.optim.defaults.setdefault('differentiable', False) + self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle. + self.optim.defaults.setdefault("differentiable", False) def load_state_dict(self, state_dict: dict): """Loads optimizer state from complete optimizer state_dict. @@ -669,16 +687,15 @@ def load_state_dict(self, state_dict: dict): state_dict (dict): optimizer state. Should be an object returned from a call to :meth:`state_dict`. """ - assert 'param_groups' in state_dict - assert 'state' in state_dict - self.load_param_groups(state_dict['param_groups']) - self.load_param_states(state_dict['state']) + assert "param_groups" in state_dict + assert "state" in state_dict + self.load_param_groups(state_dict["param_groups"]) + self.load_param_states(state_dict["state"]) self.optimizer_loading_epilogue() - def state_shard(self, - prefix: str = '', - max_shard_size: int = 1024, - only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]: + def state_shard( + self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True + ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing shards of optimizer states one by one. The max size of each dictionary shard is specified by ``max_shard_size``. @@ -694,7 +711,6 @@ def state_shard(self, sharder = StateDictSharder(max_shard_size) for param_id in self.id_to_real_params.keys(): - dist.barrier() state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) @@ -705,19 +721,20 @@ def state_shard(self, yield sharder.current_block, sharder.current_block_size def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: - raise NotImplementedError('Gemini does not support clip_grad_by_value') + raise NotImplementedError("Gemini does not support clip_grad_by_value") - def clip_grad_by_norm(self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2, - error_if_nonfinite: bool = False, - *args, - **kwargs) -> torch.Tensor: - warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm') + def clip_grad_by_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2, + error_if_nonfinite: bool = False, + *args, + **kwargs, + ) -> torch.Tensor: + warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm") class GeminiAdamOptimizer(GeminiOptimizer): - def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: optimizer = HybridAdam(model.parameters(), **defaults) super().__init__(optimizer, model, **defaults) diff --git a/colossalai/zero/gemini/memory_tracer/__init__.py b/colossalai/zero/gemini/memory_tracer/__init__.py index e1fe904ebf1a..cb7f626ff446 100644 --- a/colossalai/zero/gemini/memory_tracer/__init__.py +++ b/colossalai/zero/gemini/memory_tracer/__init__.py @@ -1,10 +1,14 @@ -from .param_runtime_order import OrderedParamGenerator # isort:skip -from .memory_stats import MemStats # isort:skip -from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip -from .memstats_collector import MemStatsCollector # isort:skip -from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip +from .param_runtime_order import OrderedParamGenerator # isort:skip +from .memory_stats import MemStats # isort:skip +from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip +from .memstats_collector import MemStatsCollector # isort:skip +from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip __all__ = [ - 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', 'MemStats', - 'OrderedParamGenerator' + "AsyncMemoryMonitor", + "SyncCudaMemoryMonitor", + "MemStatsCollector", + "ChunkMemStatsCollector", + "MemStats", + "OrderedParamGenerator", ] diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py index b93ad2c44104..b5e40a817e58 100644 --- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -8,7 +8,6 @@ class ChunkMemStatsCollector(MemStatsCollector): - def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: """ @@ -27,10 +26,11 @@ def record_model_data_volume(self) -> None: record model data volume on cuda and cpu. """ if self._start_flag and not self.use_outside_memstats: - cuda_mem = self._chunk_manager.total_mem['cuda'] + cuda_mem = self._chunk_manager.total_mem["cuda"] self._memstats.record_max_cuda_model_data(cuda_mem) @property def cuda_margin_mem(self) -> float: from colossalai.legacy.utils.memory import colo_device_memory_capacity + return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py index 2a65d4b55409..513a6326d5f1 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py +++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py @@ -111,6 +111,7 @@ def finish(self): def _measure_usage(self): from colossalai.legacy.utils import colo_device_memory_used + max_usage = 0 while self.keep_measuring: max_usage = max( diff --git a/colossalai/zero/gemini/memory_tracer/memory_stats.py b/colossalai/zero/gemini/memory_tracer/memory_stats.py index 02de6ecb97a9..1c141169f045 100644 --- a/colossalai/zero/gemini/memory_tracer/memory_stats.py +++ b/colossalai/zero/gemini/memory_tracer/memory_stats.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch @@ -6,7 +6,6 @@ class MemStats(object): - def __init__(self) -> None: """ Store the non model data statistics used for Gemini and GeminiOptimizer. @@ -92,17 +91,17 @@ def param_order(self): return self._param_runtime_order def non_model_data_list(self, device_type: str) -> List[int]: - if device_type == 'cuda': + if device_type == "cuda": return self._non_model_data_cuda_list - elif device_type == 'cpu': + elif device_type == "cpu": return self._non_model_data_cpu_list else: raise TypeError def max_non_model_data(self, device_type: str) -> float: - if device_type == 'cuda': + if device_type == "cuda": return max(self._non_model_data_cuda_list) - elif device_type == 'cpu': + elif device_type == "cpu": return max(self._non_model_data_cpu_list) else: raise TypeError diff --git a/colossalai/zero/gemini/memory_tracer/memstats_collector.py b/colossalai/zero/gemini/memory_tracer/memstats_collector.py index abb3dcc74b27..e4459831109a 100644 --- a/colossalai/zero/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/memstats_collector.py @@ -40,11 +40,12 @@ def next_period_non_model_data_usage(self, device_type: str) -> int: Returns: int: max non model data memory usage of current sampling period """ - assert not self._start_flag, 'Cannot get mem stats info during collection phase.' - assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' - assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \ - f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\ + assert not self._start_flag, "Cannot get mem stats info during collection phase." + assert self._step_total > 0, "Cannot get mem stats info before collection phase." + assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, ( + f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, " f"step total {self._step_total}" + ) next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx] self._step_idx = (self._step_idx + 1) % self._step_total return next_non_model_data @@ -60,9 +61,9 @@ def start_collection(self): def finish_collection(self): self.sample_overall_data() # self._step_total = len(self._sampling_time) - self._step_total = len(self._memstats.non_model_data_list('cuda')) + self._step_total = len(self._memstats.non_model_data_list("cuda")) self._start_flag = False - print(f'finish_collection {self._step_total}') + print(f"finish_collection {self._step_total}") # deprecated def record_model_data_volume(self) -> None: @@ -73,7 +74,7 @@ def record_model_data_volume(self) -> None: from colossalai.legacy.zero.gemini import StatefulTensor # The following code work for ZeroInitContext, which is deprecated in v0.1.12 - cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] + cuda_mem = StatefulTensor.GST_MGR.total_mem["cuda"] self._memstats.record_max_cuda_model_data(cuda_mem) def sample_overall_data(self) -> None: diff --git a/colossalai/zero/gemini/memory_tracer/param_runtime_order.py b/colossalai/zero/gemini/memory_tracer/param_runtime_order.py index 638c0533ce92..670edb9ec0d2 100644 --- a/colossalai/zero/gemini/memory_tracer/param_runtime_order.py +++ b/colossalai/zero/gemini/memory_tracer/param_runtime_order.py @@ -4,7 +4,6 @@ class ParamGenerator(ABC): - def append(self, param: torch.nn.Parameter): pass diff --git a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py index 6656821fef74..b0d258824d2b 100644 --- a/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py @@ -10,10 +10,10 @@ from .memory_stats import MemStats -__all__ = ['RuntimeMemTracer'] +__all__ = ["RuntimeMemTracer"] -class RuntimeMemTracer(): +class RuntimeMemTracer: """RuntimeMemTracer for the module training using ColoParameter. Trace non-model memory usage during fwd+bwd process. diff --git a/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py index b8f9a095f422..2a1a3745f81c 100644 --- a/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py @@ -15,9 +15,9 @@ class ModuleInfos: - - def __init__(self, module: torch.nn.Module, module_name: str, module_full_name: str, - parent_module: torch.nn.Module): + def __init__( + self, module: torch.nn.Module, module_name: str, module_full_name: str, parent_module: torch.nn.Module + ): self.module = module self.module_name = module_name self.module_full_name = module_full_name @@ -35,14 +35,13 @@ def __init__(self, module: nn.Module, chunk_manager: ChunkManager) -> None: self.module_info_list = [] def init_mem_stats(self, *inputs): - self.register_opnodes_recursively(self.module) self.refactor_module() self.module = self.module.cpu() self.module.train() - data = [MetaTensor(torch.rand(inp.shape, device='meta'), fake_device='cpu') for inp in inputs] + data = [MetaTensor(torch.rand(inp.shape, device="meta"), fake_device="cpu") for inp in inputs] gm = symbolic_trace(self.module) interp = MetaInfoProp(gm) interp.propagate(*data) @@ -87,12 +86,13 @@ def recover_module(self): for modInfo in self.module_info_list: modInfo.parent_module.__setattr__(modInfo.module_name, modInfo.module) - def register_opnodes_recursively(self, - module: torch.nn.Module, - name: str = "", - full_name: str = "", - parent_module: Optional[torch.nn.Module] = None): - + def register_opnodes_recursively( + self, + module: torch.nn.Module, + name: str = "", + full_name: str = "", + parent_module: Optional[torch.nn.Module] = None, + ): assert isinstance(module, torch.nn.Module) for child_name, child in module.named_children(): diff --git a/colossalai/zero/gemini/memory_tracer/utils.py b/colossalai/zero/gemini/memory_tracer/utils.py index 65f6ba775139..9faf81af63d7 100644 --- a/colossalai/zero/gemini/memory_tracer/utils.py +++ b/colossalai/zero/gemini/memory_tracer/utils.py @@ -14,7 +14,7 @@ def colo_model_optimizer_usage(optim) -> Tuple[int, int]: """ if optim is None: return 0, 0 - assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()" + assert hasattr(optim, "get_memory_usage"), f"{type(optim)} has no attr get_memory_usage()" return optim.get_memory_usage() @@ -35,16 +35,16 @@ def _get_tensor_mem_use(t: Optional[torch.Tensor]): return 0, 0 assert isinstance(t, torch.Tensor) _cpu_mem_usage, _cuda_mem_usage = 0, 0 - if t.device.type == 'cpu': + if t.device.type == "cpu": _cpu_mem_usage += t.numel() * t.element_size() - elif t.device.type == 'cuda': + elif t.device.type == "cuda": _cuda_mem_usage += t.numel() * t.element_size() return _cuda_mem_usage, _cpu_mem_usage cuda_mem_usage = 0 cpu_mem_usage = 0 for param in model.parameters(): - if hasattr(param, 'colo_attr'): + if hasattr(param, "colo_attr"): t_cuda, t_cpu = param.colo_attr.get_memory_usage() cuda_mem_usage += t_cuda cpu_mem_usage += t_cpu diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index a35529723a68..8a74eb587b83 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -17,10 +17,9 @@ class PlacementPolicy(ABC): need_mem_stats: bool = False - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None, - **kwargs) -> None: + def __init__( + self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs + ) -> None: self.chunk_manager = chunk_manager self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector @@ -29,23 +28,25 @@ def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, f raise NotImplementedError @abstractmethod - def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, - torch.device]) -> None: + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: raise NotImplementedError class StaticPlacementPolicy(PlacementPolicy): - - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None, - shard_param_frac: float = 1.0, - offload_optim_frac: float = 0.0, - offload_param_frac: float = 0.0, - **kwargs) -> None: + def __init__( + self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + shard_param_frac: float = 1.0, + offload_optim_frac: float = 0.0, + offload_param_frac: float = 0.0, + **kwargs, + ) -> None: super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0): - warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0') + warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0") offload_param_frac = 0.0 self.shard_param_frac = shard_param_frac self.offload_optim_frac = offload_optim_frac @@ -66,13 +67,14 @@ def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, f for chunk in can_evict_chunks: if can_offload_chunk_mem <= self.keep_cuda_chunk_mem: break - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + self.chunk_manager.move_chunk(chunk, torch.device("cpu")) # real saved mem is shard_mem, for simplicity we use chunk_mem can_offload_chunk_mem -= chunk.chunk_mem return 0, 0.0 - def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, - torch.device]) -> None: + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params) offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac @@ -85,7 +87,7 @@ def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[ if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem: device = get_current_device() else: - device = torch.device('cpu') + device = torch.device("cpu") # real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here offloaded_optim_chunk_mem += chunk.chunk_mem for p in params: @@ -97,12 +99,14 @@ def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[ class AutoPlacementPolicy(PlacementPolicy): need_mem_stats: bool = True - def __init__(self, - chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None, - warmup_non_model_data_ratio: float = 0.8, - steady_cuda_cap_ratio: float = 0.9, - **kwargs) -> None: + def __init__( + self, + chunk_manager: ChunkManager, + mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + warmup_non_model_data_ratio: float = 0.8, + steady_cuda_cap_ratio: float = 0.9, + **kwargs, + ) -> None: super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio() @@ -110,13 +114,15 @@ def __init__(self, self._warmup_non_model_data_ratio = warmup_non_model_data_ratio self._steady_cuda_cap_ratio = steady_cuda_cap_ratio - def evict_tensors(self, - can_evict_chunks: List[Chunk], - cuda_demand: int = 0, - warmup: bool = True, - compute_list: Optional[List[Tuple[Chunk, ...]]] = None, - compute_idx: int = 0, - **kwargs) -> Tuple[int, float]: + def evict_tensors( + self, + can_evict_chunks: List[Chunk], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: Optional[List[Tuple[Chunk, ...]]] = None, + compute_idx: int = 0, + **kwargs, + ) -> Tuple[int, float]: """ Evict tensors from CUDA device. @@ -135,13 +141,13 @@ def evict_tensors(self, """ start = time() cuda_capacity = colo_device_memory_capacity(get_current_device()) - used_cuda_model_data = self.chunk_manager.total_mem['cuda'] + used_cuda_model_data = self.chunk_manager.total_mem["cuda"] if warmup: # We designate a part of CUDA memory for model data in warmup iterations. max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio else: # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. - max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') + max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage("cuda") cuda_capacity *= self._steady_cuda_cap_ratio total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data @@ -160,11 +166,13 @@ def evict_tensors(self, break self.chunk_manager.release_chunk(chunk) - self.chunk_manager.move_chunk(chunk, torch.device('cpu')) + self.chunk_manager.move_chunk(chunk, torch.device("cpu")) freed_cuda_model_data += chunk.chunk_mem if freed_cuda_model_data < to_free_cuda_model_data: - raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! " - f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}") + raise RuntimeError( + f"Adjust layout failed! No enough CUDA memory! " + f"Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" + ) return freed_cuda_model_data, time() - start @staticmethod @@ -178,8 +186,9 @@ def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_li next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) return [t for (t, idx) in next_compute_idx] - def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, - torch.device]) -> None: + def setup_grads_device( + self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device] + ) -> None: for p in params: chunk = self.chunk_manager.get_chunk(p) # init offload optim settings @@ -187,13 +196,13 @@ def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[ if chunk.keep_gathered: grads_device_map[p] = get_current_device() else: - grads_device_map[p] = torch.device('cpu') + grads_device_map[p] = torch.device("cpu") class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = { - 'auto': AutoPlacementPolicy, - 'static': StaticPlacementPolicy, + "auto": AutoPlacementPolicy, + "static": StaticPlacementPolicy, } @staticmethod diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py index 0d92d32e5603..264099d22de2 100644 --- a/colossalai/zero/gemini/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -27,16 +27,15 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk): return total_temp -def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ''): - """Get a dfs module list of the given module. Its order is same as the order of creations of modules. - """ +def _get_dfs_module_list(module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = ""): + """Get a dfs module list of the given module. Its order is same as the order of creations of modules.""" if memo is None: memo = set() if module not in memo: for name, submodule in module._modules.items(): if submodule is None: continue - submodule_prefix = prefix + ('.' if prefix else '') + name + submodule_prefix = prefix + ("." if prefix else "") + name for m in _get_dfs_module_list(submodule, memo, submodule_prefix): yield m @@ -60,10 +59,9 @@ def _get_shallow_copy_model(model: nn.Module): return old_to_new[model] -def get_static_torch_model(zero_ddp_model, - device=torch.device("cpu"), - dtype=torch.float32, - only_rank_0=True) -> torch.nn.Module: +def get_static_torch_model( + zero_ddp_model, device=torch.device("cpu"), dtype=torch.float32, only_rank_0=True +) -> torch.nn.Module: """Get a static torch.nn.Module model from the given GeminiDDP module. You should notice that the original GeminiDDP model is not modified. Thus, you can use the original model in further training. @@ -79,6 +77,7 @@ def get_static_torch_model(zero_ddp_model, torch.nn.Module: a static torch model used for saving checkpoints or numeric checks """ from colossalai.zero.gemini.gemini_ddp import GeminiDDP + assert isinstance(zero_ddp_model, GeminiDDP) state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0) @@ -86,15 +85,17 @@ def get_static_torch_model(zero_ddp_model, torch_model = _get_shallow_copy_model(colo_model) if not only_rank_0 or dist.get_rank() == 0: - for (name, colo_module), (_, torch_module) in \ - zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)): + for (name, colo_module), (_, torch_module) in zip( + _get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model) + ): # clean the parameter list of the new torch module torch_module._parameters = OrderedDict() for sufix_param_name, param in colo_module.named_parameters(recurse=False): # get the full name of the parameter - full_param_name = name + ('.' if name else '') + sufix_param_name - assert full_param_name in state_dict, \ - f"Can not find parameter `{full_param_name}` in the GeminiDDP module" + full_param_name = name + ("." if name else "") + sufix_param_name + assert ( + full_param_name in state_dict + ), f"Can not find parameter `{full_param_name}` in the GeminiDDP module" state_param = state_dict[full_param_name] torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype)) diff --git a/colossalai/zero/low_level/__init__.py b/colossalai/zero/low_level/__init__.py index ae3c1de3a5bc..270a6a6a4786 100644 --- a/colossalai/zero/low_level/__init__.py +++ b/colossalai/zero/low_level/__init__.py @@ -1,3 +1,3 @@ from .low_level_optim import LowLevelZeroOptimizer -__all__ = ['LowLevelZeroOptimizer'] +__all__ = ["LowLevelZeroOptimizer"] diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index ece92fe02e28..ba1135940df0 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -44,8 +44,8 @@ def shuffle_by_round_robin(tensor_list, num_partitions): for partition_id in range(partitions_count): partition_tensors = partitions[partition_id] for item in partition_tensors: - tensor_index_mapping[item['index']] = len(new_tensor_list) - new_tensor_list.append(item['tensor']) + tensor_index_mapping[item["index"]] = len(new_tensor_list) + new_tensor_list.append(item["tensor"]) return new_tensor_list, tensor_index_mapping @@ -107,11 +107,13 @@ def split_by_dtype(tensor_list): return buckets -def reduce_tensor_dp_group(tensor: torch.Tensor, - dtype: Optional[torch.dtype] = None, - dst_local_rank: Optional[int] = None, - dst_global_rank: Optional[int] = None, - group: Optional[dist.ProcessGroup] = None): +def reduce_tensor_dp_group( + tensor: torch.Tensor, + dtype: Optional[torch.dtype] = None, + dst_local_rank: Optional[int] = None, + dst_global_rank: Optional[int] = None, + group: Optional[dist.ProcessGroup] = None, +): """ Reduce the tensor in the data parallel process group @@ -173,7 +175,7 @@ def has_inf_or_nan(tensor): raise return True else: - if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: + if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum: return True return False @@ -184,8 +186,7 @@ def release_param_grad(tensor_list): def calculate_global_norm_from_list(norm_list): - """ Compute total from a list of norms - """ + """Compute total from a list of norms""" total_norm = 0.0 for norm in norm_list: total_norm += norm**2.0 @@ -221,7 +222,7 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro total_norm = 0.0 for g in gradients: param_norm = g.data.double().norm(2) - total_norm += param_norm.item()**2 + total_norm += param_norm.item() ** 2 # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) @@ -230,9 +231,9 @@ def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGro if tp_group is not None: dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type) - if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: + if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm: total_norm = -1 return total_norm diff --git a/colossalai/zero/low_level/bookkeeping/__init__.py b/colossalai/zero/low_level/bookkeeping/__init__.py index 7bcacfabfded..427973772f9c 100644 --- a/colossalai/zero/low_level/bookkeeping/__init__.py +++ b/colossalai/zero/low_level/bookkeeping/__init__.py @@ -3,4 +3,4 @@ from .parameter_store import ParameterStore from .tensor_bucket import TensorBucket -__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket'] +__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"] diff --git a/colossalai/zero/low_level/bookkeeping/base_store.py b/colossalai/zero/low_level/bookkeeping/base_store.py index 2ebd122464f4..107d62dcbc0e 100644 --- a/colossalai/zero/low_level/bookkeeping/base_store.py +++ b/colossalai/zero/low_level/bookkeeping/base_store.py @@ -3,7 +3,6 @@ class BaseStore: - def __init__(self, torch_pg: ProcessGroup): self._world_size = dist.get_world_size(group=torch_pg) self._local_rank = dist.get_rank(group=torch_pg) diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 0ab10e25d407..2a75d704711a 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -9,7 +9,6 @@ class BucketStore(BaseStore): - def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) @@ -38,8 +37,7 @@ def num_elements_in_bucket(self) -> int: return self._num_elements_in_bucket def reset_num_elements_in_bucket(self): - """Set the number of elements in bucket to zero. - """ + """Set the number of elements in bucket to zero.""" self._num_elements_in_bucket = 0 @@ -54,7 +52,7 @@ def add_param_grad(self, group_id: int, param: Tensor, padding_size: int): self._param_list.append(param) self._padding_size.append(padding_size) - self._num_elements_in_bucket += (param.numel() + padding_size) + self._num_elements_in_bucket += param.numel() + padding_size self.current_group_id = group_id # number of tensors in current bucket @@ -119,8 +117,7 @@ def get_param_id_of_grad(self, grad: Tensor) -> int: return self.grad_to_param_mapping[id(grad)] def reset(self): - """Reset the bucket storage after reduction, only release the tensors have been reduced - """ + """Reset the bucket storage after reduction, only release the tensors have been reduced""" cur_offset = self.offset_list.pop(0) self._param_list = self._param_list[cur_offset:] self._padding_size = self._padding_size[cur_offset:] diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 2890b329a642..3ce688cfa930 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -1,13 +1,11 @@ from typing import List from torch import Tensor -from torch._utils import _flatten_dense_tensors from .base_store import BaseStore class GradientStore(BaseStore): - def __init__(self, *args, partition_grad: bool = False): super().__init__(*args) """ diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py index 63f7c5506069..e94fb4de9b9f 100644 --- a/colossalai/zero/low_level/bookkeeping/parameter_store.py +++ b/colossalai/zero/low_level/bookkeeping/parameter_store.py @@ -5,7 +5,6 @@ class ParameterStore(BaseStore): - def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py index b32816a046cd..16ba8a6d6445 100644 --- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py +++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py @@ -2,7 +2,6 @@ class TensorBucket: - def __init__(self, size): self._max_size = size self._current_size = 0 @@ -26,8 +25,7 @@ def add_to_bucket(self, tensor, allow_oversize=False): tensor_size = tensor.numel() if not allow_oversize and self.will_exceed_max_size(tensor_size): - msg = f"The param bucket max size {self._max_size} is exceeded" \ - + f"by tensor (size {tensor_size})" + msg = f"The param bucket max size {self._max_size} is exceeded" + f"by tensor (size {tensor_size})" raise RuntimeError(msg) self._bucket.append(tensor) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 0bdd6a3e2370..1bf5302efcfb 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -17,6 +17,7 @@ ) from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger + # from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.utils.cuda import get_current_device @@ -32,19 +33,21 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): - - def __init__(self, - num_working_param_groups: int, - grad_store: GradientStore, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32) -> None: - super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, - max_scale) + def __init__( + self, + num_working_param_groups: int, + grad_store: GradientStore, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + ) -> None: + super().__init__( + initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale + ) self.num_working_param_groups = num_working_param_groups self.grad_store = grad_store @@ -57,32 +60,31 @@ def check_local_overflow(self) -> bool: class LowLevelZeroOptimizer(OptimizerWrapper): - """Optimizer used for ZeRO-1 and ZeRO-2. - """ + """Optimizer used for ZeRO-1 and ZeRO-2.""" def __init__( - self, - optimizer: Optimizer, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2., - backoff_factor: float = .5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = False, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp - forced_dtype: Optional[torch.dtype] = None): - + self, + optimizer: Optimizer, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = False, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None, + ): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) - self._dtype = self.optim.param_groups[0]['params'][0].dtype + self._dtype = self.optim.param_groups[0]["params"][0].dtype self._logger = get_dist_logger() self._verbose = verbose @@ -115,7 +117,7 @@ def __init__( if forced_dtype: for group in self.optim.param_groups: - group_params = group['params'] + group_params = group["params"] for param in group_params: param.data = param.data.to(forced_dtype) self._dtype = forced_dtype @@ -134,7 +136,7 @@ def __init__( # and add buffers to parameter store for future access for group_id, param_group in enumerate(self.optim.param_groups): group_params = list() - for param in param_group['params']: + for param in param_group["params"]: if param.requires_grad: group_params.append(param) @@ -148,7 +150,7 @@ def __init__( # need to replace the params in the `params` field in the optimizer # so that when the optimizer calls step(), it only updates the tensors # managed by this data parallel rank - param_group['params'] = master_param_current_rank + param_group["params"] = master_param_current_rank # intialize communication stream for # communication-compuation overlapping @@ -164,15 +166,17 @@ def __init__( # initialize mixed precision mixin self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None if self._dtype is torch.float16: - self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups, - self._grad_store, - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) + self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin( + self.num_param_groups, + self._grad_store, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) elif self._dtype is torch.bfloat16: self.mixed_precision_mixin = BF16MixedPrecisionMixin() @@ -185,17 +189,18 @@ def num_param_groups(self): return len(self._working_param_groups) def _sanity_checks(self): - assert torch.cuda.is_available(), 'CUDA is required' + assert torch.cuda.is_available(), "CUDA is required" for param_group in self.optim.param_groups: - group_params = param_group['params'] + group_params = param_group["params"] for param in group_params: - assert param.dtype == self._dtype, \ - f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + assert ( + param.dtype == self._dtype + ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" def _create_master_param_current_rank(self, param_list): # split each param evenly by world size params_current_rank = [] - device = 'cpu' if self._cpu_offload else get_current_device() + device = "cpu" if self._cpu_offload else get_current_device() for param in param_list: padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size @@ -275,8 +280,10 @@ def _run_reduction(self): sync_tensor(flat_grads_per_rank[rank], grad_list) for grad in grad_list: param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, - param_id)) < self._world_size: + if ( + len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) + < self._world_size + ): self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) else: self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) @@ -307,8 +314,10 @@ def _add_to_bucket(self, param, group_id): # if full, will reduce the grads already in the bucket # or got a grad of param from another group # after reduction, the bucket will be empty - if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \ - group_id != self._bucket_store.current_group_id: + if ( + self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size + or group_id != self._bucket_store.current_group_id + ): self._run_reduction() padding_size = self._param_store.get_param_padding_size(param) @@ -319,8 +328,9 @@ def _add_to_bucket(self, param, group_id): ################################ def backward(self, loss, retain_graph=False): - assert not(self._partition_grads and not self.require_grad_sync), \ - "ZeRO2(partition_grads) and no_sync are not compatible" + assert not ( + self._partition_grads and not self.require_grad_sync + ), "ZeRO2(partition_grads) and no_sync are not compatible" if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) @@ -339,8 +349,9 @@ def backward(self, loss, retain_graph=False): self.zero_grad() def backward_by_grad(self, tensor, grad): - assert not(self._partition_grads and not self.require_grad_sync), \ - "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + assert not ( + self._partition_grads and not self.require_grad_sync + ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" if self.mixed_precision_mixin is not None: grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad) @@ -380,14 +391,14 @@ def zero_grad(self, set_to_none=True): #################### def step(self, closure=None): - assert closure is None, 'closure is not supported by step()' + assert closure is None, "closure is not supported by step()" if not self.require_grad_sync: return if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): self._grad_store.reset_all_gradients() if self._verbose: - self._logger.info(f'Found overflow. Skip step') + self._logger.info(f"Found overflow. Skip step") self.zero_grad() return @@ -428,7 +439,7 @@ def step(self, closure=None): self._grad_store.reset_grads_by_group_id(group_id) # update the params in the optimizer - self.optim.param_groups[group_id]['params'] = real_master_params[group_id] + self.optim.param_groups[group_id]["params"] = real_master_params[group_id] # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) @@ -445,16 +456,16 @@ def step(self, closure=None): # update working partition updated by the current rank dtype = real_working_params[0][0].dtype for group_id in range(self.num_param_groups): - master_working_param = self.optim.param_groups[group_id]['params'] + master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] all_splited_param = [ torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size) ] dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg) - working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param)) + working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) - self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id] + self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] ############################# # Mixed Precision Utilities # @@ -466,14 +477,14 @@ def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): if self.mixed_precision_mixin is not None: div_scale = self.mixed_precision_mixin.get_grad_div_scale() - if self._clip_grad_norm > 0.: + if self._clip_grad_norm > 0.0: # norm is in fact norm*scale clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm if clip > 1: div_scale = clip * div_scale for grad in grad_groups_flat: - grad.data.mul_(1. / div_scale) + grad.data.mul_(1.0 / div_scale) ############################ # Gradient Synchronization # @@ -518,18 +529,19 @@ def _pack_state(self, state: Dict) -> Dict: def pack_group(group): nonlocal start_index - packed = {k: v for k, v in group.items() if k != 'params'} + packed = {k: v for k, v in group.items() if k != "params"} param_mappings.update( - {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings}) - packed['params'] = [param_mappings[id(p)] for p in group['params']] - start_index += len(packed['params']) + {id(p): i for i, p in enumerate(group["params"], start_index) if id(p) not in param_mappings} + ) + packed["params"] = [param_mappings[id(p)] for p in group["params"]] + start_index += len(packed["params"]) return packed param_groups = [pack_group(g) for g in self.optim.param_groups] # Remap state to use order indices as keys packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()} - return {'state': packed_state, 'param_groups': param_groups} + return {"state": packed_state, "param_groups": param_groups} def state_dict(self) -> Dict: """Return a state_dict same with DDP @@ -541,14 +553,15 @@ def state_dict(self) -> Dict: for param, state in self.optim.state.items(): zero_state[param] = copy.deepcopy(state) for k, v in state.items(): - if isinstance(v, torch.Tensor) and k != 'step': + if isinstance(v, torch.Tensor) and k != "step": working_param = self._param_store.master_to_working_param[id(param)] gather_tensor = [ - torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) + torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) ] dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) - param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( - working_param).cpu() + param_state = ( + torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) zero_state[param][k] = param_state states_dict = self._pack_state(zero_state) @@ -562,16 +575,16 @@ def load_state_dict(self, state_dict: Dict): state_dict (dict): A pytorch form state_dict """ zero_state_dict = copy.deepcopy(state_dict) - for param_idx, state in zero_state_dict['state'].items(): + for param_idx, state in zero_state_dict["state"].items(): for k, v in state.items(): - if isinstance(v, torch.Tensor) and k != 'step': + if isinstance(v, torch.Tensor) and k != "step": padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size with torch.no_grad(): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) v_list = v.split(v.numel() // self._world_size) - zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone() + zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone() self.optim.load_state_dict(zero_state_dict) @@ -588,7 +601,7 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i ret_block = dict() ret_block_size = 0 - local_states = self.optim.state_dict()['state'] + local_states = self.optim.state_dict()["state"] for param_idx, states in local_states.items(): current_block_size = 0 current_block = copy.deepcopy(states) @@ -601,11 +614,12 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i working_param = self._param_store.master_to_working_param[id(master_param)] for k, v in states.items(): - if isinstance(v, torch.Tensor) and k != 'step': - state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)] + if isinstance(v, torch.Tensor) and k != "step": + state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)] dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) - state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as( - working_param).cpu() + state_tensor = ( + torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() + ) current_block_size += state_tensor.numel() current_block[k] = state_tensor diff --git a/colossalai/zero/wrapper.py b/colossalai/zero/wrapper.py index 90325fe0a704..ed873254e301 100644 --- a/colossalai/zero/wrapper.py +++ b/colossalai/zero/wrapper.py @@ -7,10 +7,9 @@ from .gemini import GeminiDDP -def zero_model_wrapper(model: nn.Module, - zero_stage: int = 1, - gemini_config: Optional[Dict] = None, - verbose: bool = False): +def zero_model_wrapper( + model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None, verbose: bool = False +): """This wrapper function is used to wrap your training model for ZeRO DDP. Example: @@ -50,19 +49,21 @@ def zero_model_wrapper(model: nn.Module, return wrapped_model -def zero_optim_wrapper(model: nn.Module, - optimizer: torch.optim.Optimizer, - initial_scale: float = 2**16, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - min_scale: float = 1, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0, - optim_config: Optional[Dict] = None, - verbose: bool = False): +def zero_optim_wrapper( + model: nn.Module, + optimizer: torch.optim.Optimizer, + initial_scale: float = 2**16, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + min_scale: float = 1, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + optim_config: Optional[Dict] = None, + verbose: bool = False, +): """This wrapper function is used to wrap your training optimizer for ZeRO DDP. Args: @@ -95,20 +96,22 @@ def zero_optim_wrapper(model: nn.Module, else: config_dict = copy(optim_config) - config_dict['initial_scale'] = initial_scale - config_dict['growth_factor'] = growth_factor - config_dict['backoff_factor'] = backoff_factor - config_dict['growth_interval'] = growth_interval - config_dict['hysteresis'] = hysteresis - config_dict['min_scale'] = min_scale - config_dict['max_scale'] = max_scale + config_dict["initial_scale"] = initial_scale + config_dict["growth_factor"] = growth_factor + config_dict["backoff_factor"] = backoff_factor + config_dict["growth_interval"] = growth_interval + config_dict["hysteresis"] = hysteresis + config_dict["min_scale"] = min_scale + config_dict["max_scale"] = max_scale if zero_stage in [1, 2]: from colossalai.zero.low_level import LowLevelZeroOptimizer - config_dict['partition_grad'] = zero_stage == 2 - config_dict['clip_grad_norm'] = max_norm + + config_dict["partition_grad"] = zero_stage == 2 + config_dict["clip_grad_norm"] = max_norm return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose) else: from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer - config_dict['clipping_norm'] = max_norm + + config_dict["clipping_norm"] = max_norm return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose) diff --git a/examples/community/fp8/mnist/main.py b/examples/community/fp8/mnist/main.py index a534663d380f..2bb912dec247 100644 --- a/examples/community/fp8/mnist/main.py +++ b/examples/community/fp8/mnist/main.py @@ -13,13 +13,13 @@ try: from transformer_engine import pytorch as te + HAVE_TE = True except (ImportError, ModuleNotFoundError): HAVE_TE = False class Net(nn.Module): - def __init__(self, use_te=False): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) @@ -64,10 +64,12 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8): loss.backward() optimizer.step() if batch_idx % args.log_interval == 0: - print(f"Train Epoch: {epoch} " - f"[{batch_idx * len(data)}/{len(train_loader.dataset)} " - f"({100. * batch_idx / len(train_loader):.0f}%)]\t" - f"Loss: {loss.item():.6f}") + print( + f"Train Epoch: {epoch} " + f"[{batch_idx * len(data)}/{len(train_loader.dataset)} " + f"({100. * batch_idx / len(train_loader):.0f}%)]\t" + f"Loss: {loss.item():.6f}" + ) if args.dry_run: break @@ -75,13 +77,11 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8): def calibrate(model, device, test_loader): """Calibration function.""" model.eval() - test_loss = 0 - correct = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) with te.fp8_autocast(enabled=False, calibrating=True): - output = model(data) + model(data) def test(model, device, test_loader, use_fp8): @@ -94,15 +94,17 @@ def test(model, device, test_loader, use_fp8): data, target = data.to(device), target.to(device) with te.fp8_autocast(enabled=use_fp8): output = model(data) - test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print(f"\nTest set: Average loss: {test_loss:.4f}, " - f"Accuracy: {correct}/{len(test_loader.dataset)} " - f"({100. * correct / len(test_loader.dataset):.0f}%)\n") + print( + f"\nTest set: Average loss: {test_loss:.4f}, " + f"Accuracy: {correct}/{len(test_loader.dataset)} " + f"({100. * correct / len(test_loader.dataset):.0f}%)\n" + ) def main(): @@ -163,10 +165,9 @@ def main(): default=False, help="For Saving the current Model", ) - parser.add_argument("--use-fp8", - action="store_true", - default=False, - help="Use FP8 for inference and training without recalibration") + parser.add_argument( + "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration" + ) parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only") parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine") args = parser.parse_args() @@ -215,7 +216,7 @@ def main(): if args.save_model or args.use_fp8_infer: torch.save(model.state_dict(), "mnist_cnn.pt") - print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer)) + print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer)) weights = torch.load("mnist_cnn.pt") model.load_state_dict(weights) test(model, device, test_loader, args.use_fp8_infer) diff --git a/examples/community/roberta/preprocessing/get_mask.py b/examples/community/roberta/preprocessing/get_mask.py index 74c97a63a9f3..f0ba8fe38501 100644 --- a/examples/community/roberta/preprocessing/get_mask.py +++ b/examples/community/roberta/preprocessing/get_mask.py @@ -1,13 +1,8 @@ import collections import logging -import os import random -import time -from enum import IntEnum -from random import choice import jieba -import torch jieba.setLogLevel(logging.CRITICAL) import re @@ -23,14 +18,15 @@ def map_to_numpy(data): return np.asarray(data) -class PreTrainingDataset(): - - def __init__(self, - tokenizer, - max_seq_length, - backend='python', - max_predictions_per_seq: int = 80, - do_whole_word_mask: bool = True): +class PreTrainingDataset: + def __init__( + self, + tokenizer, + max_seq_length, + backend="python", + max_predictions_per_seq: int = 80, + do_whole_word_mask: bool = True, + ): self.tokenizer = tokenizer self.max_seq_length = max_seq_length self.masked_lm_prob = 0.15 @@ -38,8 +34,8 @@ def __init__(self, self.do_whole_word_mask = do_whole_word_mask self.max_predictions_per_seq = max_predictions_per_seq self.vocab_words = list(tokenizer.vocab.keys()) - self.rec = re.compile('[\u4E00-\u9FA5]') - self.whole_rec = re.compile('##[\u4E00-\u9FA5]') + self.rec = re.compile("[\u4E00-\u9FA5]") + self.whole_rec = re.compile("##[\u4E00-\u9FA5]") self.mlm_p = 0.15 self.mlm_mask_p = 0.8 @@ -64,7 +60,7 @@ def create_training_instance(self, instance): original_tokens = [] segment_ids = [] tokens.append("[CLS]") - original_tokens.append('[CLS]') + original_tokens.append("[CLS]") segment_ids.append(0) for index, token in enumerate(tokens_a): tokens.append(token) @@ -72,7 +68,7 @@ def create_training_instance(self, instance): segment_ids.append(0) tokens.append("[SEP]") - original_tokens.append('[SEP]') + original_tokens.append("[SEP]") segment_ids.append(0) # for token in tokens_b: @@ -83,11 +79,16 @@ def create_training_instance(self, instance): # segment_ids.append(1) # Get Masked LM predictions - if self.backend == 'c++': + if self.backend == "c++": output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions( - tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq, - self.masked_lm_prob) - elif self.backend == 'python': + tokens, + original_tokens, + self.vocab_words, + self.tokenizer.vocab, + self.max_predictions_per_seq, + self.masked_lm_prob, + ) + elif self.backend == "python": output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens) # Convert to Ids @@ -99,20 +100,20 @@ def create_training_instance(self, instance): segment_ids.append(PAD) input_mask.append(PAD) masked_lm_output.append(-1) - return ([ + return [ map_to_numpy(input_ids), map_to_numpy(input_mask), map_to_numpy(segment_ids), map_to_numpy(masked_lm_output), - map_to_numpy([is_next]) - ]) + map_to_numpy([is_next]), + ] def create_masked_lm_predictions(self, tokens): cand_indexes = [] for i, token in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue - if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): + if self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##"): cand_indexes[-1].append(i) else: cand_indexes.append([i]) @@ -160,7 +161,7 @@ def get_new_segment(self, segment): Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word. :param segment: a sentence """ - seq_cws = jieba.lcut(''.join(segment)) + seq_cws = jieba.lcut("".join(segment)) seq_cws_dict = {x: 1 for x in seq_cws} new_segment = [] i = 0 @@ -174,10 +175,10 @@ def get_new_segment(self, segment): for length in range(3, 0, -1): if i + length > len(segment): continue - if ''.join(segment[i:i + length]) in seq_cws_dict: + if "".join(segment[i : i + length]) in seq_cws_dict: new_segment.append(segment[i]) for l in range(1, length): - new_segment.append('##' + segment[i + l]) + new_segment.append("##" + segment[i + l]) i += length has_add = True break @@ -190,7 +191,7 @@ def create_whole_masked_lm_predictions(self, tokens): """Creates the predictions for the masked LM objective.""" cand_indexes = [] - for (i, token) in enumerate(tokens): + for i, token in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue # Whole Word Masking means that if we mask all of the wordpieces @@ -202,14 +203,14 @@ def create_whole_masked_lm_predictions(self, tokens): # Note that Whole Word Masking does *not* change the training code # at all -- we still predict each WordPiece independently, softmaxed # over the entire vocabulary. - if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): + if self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##"): cand_indexes[-1].append(i) else: cand_indexes.append([i]) random.shuffle(cand_indexes) - output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens] # 去掉"##" + output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens] # 去掉"##" num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob)))) @@ -239,8 +240,9 @@ def create_whole_masked_lm_predictions(self, tokens): else: # 10% of the time, keep original if random.random() < 0.5: - masked_token = tokens[index][2:] if len(self.whole_rec.findall( - tokens[index])) > 0 else tokens[index] # 去掉"##" + masked_token = ( + tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index] + ) # 去掉"##" # 10% of the time, replace with random word else: masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)] @@ -250,7 +252,9 @@ def create_whole_masked_lm_predictions(self, tokens): masked_lms.append( MaskedLMInstance( index=index, - label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index])) + label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index], + ) + ) assert len(masked_lms) <= num_to_predict masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lm_output = [-1] * len(output_tokens) diff --git a/examples/community/roberta/preprocessing/sentence_split.py b/examples/community/roberta/preprocessing/sentence_split.py index 76e8bd428723..8c83ce095582 100644 --- a/examples/community/roberta/preprocessing/sentence_split.py +++ b/examples/community/roberta/preprocessing/sentence_split.py @@ -14,17 +14,19 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s sent_list = [] try: if flag == "zh": - document = re.sub('(?P([。?!…](?![”’"\'])))', r'\g\n', document) - document = re.sub('(?P([。?!]|…{1,2})[”’"\'])', r'\g\n', document) + document = re.sub("(?P([。?!…](?![”’\"'])))", r"\g\n", document) + document = re.sub("(?P([。?!]|…{1,2})[”’\"'])", r"\g\n", document) elif flag == "en": - document = re.sub('(?P([.?!](?![”’"\'])))', r'\g\n', document) - document = re.sub('(?P([?!.]["\']))', r'\g\n', - document) # Special quotation marks + document = re.sub("(?P([.?!](?![”’\"'])))", r"\g\n", document) + document = re.sub( + "(?P([?!.][\"']))", r"\g\n", document + ) # Special quotation marks else: - document = re.sub('(?P([。?!….?!](?![”’"\'])))', r'\g\n', document) + document = re.sub("(?P([。?!….?!](?![”’\"'])))", r"\g\n", document) - document = re.sub('(?P(([。?!.!?]|…{1,2})[”’"\']))', r'\g\n', - document) # Special quotation marks + document = re.sub( + "(?P(([。?!.!?]|…{1,2})[”’\"']))", r"\g\n", document + ) # Special quotation marks sent_list_ori = document.splitlines() for sent in sent_list_ori: @@ -46,36 +48,35 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None: - workers = 32 - if input_path[-1] == '/': + if input_path[-1] == "/": input_path = input_path[:-1] - cur_path = os.path.join(output_path, str(host) + '.txt') + cur_path = os.path.join(output_path, str(host) + ".txt") new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2) - with open(cur_path, 'w', encoding='utf-8') as f: + with open(cur_path, "w", encoding="utf-8") as f: for fi, fin_path in enumerate(fin_list): if not os.path.exists(os.path.join(input_path, fin_path[0])): continue - if '.json' not in fin_path[0]: + if ".json" not in fin_path[0]: continue print("Processing ", fin_path[0], " ", fi) - with open(os.path.join(input_path, fin_path[0]), 'r') as fin: - f_data = [l['content'] for l in json.load(fin)] + with open(os.path.join(input_path, fin_path[0]), "r") as fin: + f_data = [l["content"] for l in json.load(fin)] pool = multiprocessing.Pool(workers) all_sent = pool.imap_unordered(new_split_sentence, f_data, 32) pool.close() - print('finished..') + print("finished..") cnt = 0 for d in tqdm(all_sent): for i in d: - f.write(i.strip() + '\n') - f.write(']]' + '\n') + f.write(i.strip() + "\n") + f.write("]]" + "\n") cnt += 1 # if cnt >= 2: # exit() @@ -86,7 +87,7 @@ def getFileSize(filepath, shard): for i in os.listdir(filepath): all_data.append(os.path.join(filepath, i)) all_size = sum([os.path.getsize(os.path.join(filepath, f)) for f in all_data]) - ans = [[f.split('/')[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data] + ans = [[f.split("/")[-1], os.path.getsize(os.path.join(filepath, f))] for f in all_data] ans = sorted(ans, key=lambda x: x[1], reverse=True) per_size = all_size / shard real_shard = [] @@ -106,24 +107,24 @@ def getFileSize(filepath, shard): return real_shard -def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'): +def get_start_end(real_shard, base=0, server_num=10, server_name="GPU"): import socket + host = int(socket.gethostname().split(server_name)[-1]) fin_list = real_shard[server_num * base + host - 1] print(fin_list) - print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}') + print(f"I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}") return fin_list, host -if __name__ == '__main__': - +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--server_num', type=int, default=10, help='number of servers') - parser.add_argument('--seq_len', type=int, default=512, help='sequence length') - parser.add_argument('--shard', type=int, default=100, help='number of shards, e.g., 10, 50, or 100') - parser.add_argument('--input_path', type=str, required=True, help='input path of original corpus') - parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence') + parser.add_argument("--server_num", type=int, default=10, help="number of servers") + parser.add_argument("--seq_len", type=int, default=512, help="sequence length") + parser.add_argument("--shard", type=int, default=100, help="number of shards, e.g., 10, 50, or 100") + parser.add_argument("--input_path", type=str, required=True, help="input path of original corpus") + parser.add_argument("--output_path", type=str, required=True, help="output path of shard which has split sentence") args = parser.parse_args() server_num = args.server_num @@ -137,7 +138,7 @@ def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'): start = time.time() for index, shard in enumerate(real_shard): get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len) - print(f'cost {str(time.time() - start)}') + print(f"cost {str(time.time() - start)}") # if you have multiple server, you can use code below or modify code to openmpi diff --git a/examples/community/roberta/preprocessing/tokenize_mask.py b/examples/community/roberta/preprocessing/tokenize_mask.py index f3d49c3d965f..19dbaf5384de 100644 --- a/examples/community/roberta/preprocessing/tokenize_mask.py +++ b/examples/community/roberta/preprocessing/tokenize_mask.py @@ -1,7 +1,6 @@ import argparse import multiprocessing import os -import socket import time from random import shuffle @@ -29,8 +28,7 @@ def get_raw_instance(document, max_sequence_length=512): curr_seq = [] sz_idx = 0 while sz_idx < len(sizes): - - if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0: + if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0: curr_seq += document[sz_idx] sz_idx += 1 elif sizes[sz_idx] >= max_sequence_length_allowed: @@ -43,7 +41,7 @@ def get_raw_instance(document, max_sequence_length=512): result_list.append(curr_seq) curr_seq = [] - if len(curr_seq) > max_sequence_length_allowed / 2: # /2 + if len(curr_seq) > max_sequence_length_allowed / 2: # /2 result_list.append(curr_seq) # num_instance=int(len(big_list)/max_sequence_length_allowed)+1 @@ -58,33 +56,30 @@ def get_raw_instance(document, max_sequence_length=512): def split_numpy_chunk(path, tokenizer, pretrain_data, host): - documents = [] instances = [] s = time.time() - with open(path, encoding='utf-8') as fd: + with open(path, encoding="utf-8") as fd: document = [] for i, line in enumerate(tqdm(fd)): line = line.strip() # document = line # if len(document.split("")) <= 3: # continue - if len(line) > 0 and line[:2] == "]]": # This is end of document + if len(line) > 0 and line[:2] == "]]": # This is end of document documents.append(document) document = [] elif len(line) >= 2: document.append(line) if len(document) > 0: documents.append(document) - print('read_file ', time.time() - s) + print("read_file ", time.time() - s) # documents = [x for x in documents if x] # print(len(documents)) # print(len(documents[0])) # print(documents[0][0:10]) - import multiprocessing - from typing import List ans = [] for docs in tqdm(documents): @@ -98,7 +93,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): instances.extend(raw_ins) del ans - print('len instance', len(instances)) + print("len instance", len(instances)) sen_num = len(instances) seq_len = 512 @@ -114,7 +109,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): segment_ids[index] = mask_dict[2] masked_lm_output[index] = mask_dict[3] - with h5py.File(f'/output/{host}.h5', 'w') as hf: + with h5py.File(f"/output/{host}.h5", "w") as hf: hf.create_dataset("input_ids", data=input_ids) hf.create_dataset("input_mask", data=input_ids) hf.create_dataset("segment_ids", data=segment_ids) @@ -124,45 +119,44 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name): - - if os.path.exists(os.path.join(output_path, f'{file_name}.h5')): - print(f'{file_name}.h5 exists') + if os.path.exists(os.path.join(output_path, f"{file_name}.h5")): + print(f"{file_name}.h5 exists") return documents = [] instances = [] s = time.time() - with open(input_path, 'r', encoding='utf-8') as fd: + with open(input_path, "r", encoding="utf-8") as fd: document = [] for i, line in enumerate(tqdm(fd)): line = line.strip() - if len(line) > 0 and line[:2] == "]]": # This is end of document + if len(line) > 0 and line[:2] == "]]": # This is end of document documents.append(document) document = [] elif len(line) >= 2: document.append(line) if len(document) > 0: documents.append(document) - print(f'read_file cost {time.time() - s}, length is {len(documents)}') + print(f"read_file cost {time.time() - s}, length is {len(documents)}") ans = [] s = time.time() pool = multiprocessing.Pool(worker) encoded_doc = pool.imap_unordered(pretrain_data.tokenize, documents, 100) - for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour='cyan'): + for index, res in tqdm(enumerate(encoded_doc, start=1), total=len(documents), colour="cyan"): ans.append(res) pool.close() print((time.time() - s) / 60) del documents instances = [] - for a in tqdm(ans, colour='MAGENTA'): + for a in tqdm(ans, colour="MAGENTA"): raw_ins = get_raw_instance(a, max_sequence_length=seq_len) instances.extend(raw_ins) del ans - print('len instance', len(instances)) + print("len instance", len(instances)) new_instances = [] for _ in range(dupe_factor): @@ -171,7 +165,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ shuffle(new_instances) instances = new_instances - print('after dupe_factor, len instance', len(instances)) + print("after dupe_factor, len instance", len(instances)) sentence_num = len(instances) input_ids = np.zeros([sentence_num, seq_len], dtype=np.int32) @@ -182,7 +176,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ s = time.time() pool = multiprocessing.Pool(worker) encoded_docs = pool.imap_unordered(pretrain_data.create_training_instance, instances, 32) - for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour='blue'): + for index, mask_dict in tqdm(enumerate(encoded_docs), total=len(instances), colour="blue"): input_ids[index] = mask_dict[0] input_mask[index] = mask_dict[1] segment_ids[index] = mask_dict[2] @@ -190,7 +184,7 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ pool.close() print((time.time() - s) / 60) - with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf: + with h5py.File(os.path.join(output_path, f"{file_name}.h5"), "w") as hf: hf.create_dataset("input_ids", data=input_ids) hf.create_dataset("input_mask", data=input_mask) hf.create_dataset("segment_ids", data=segment_ids) @@ -199,50 +193,48 @@ def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_ del instances -if __name__ == '__main__': - +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer') - parser.add_argument('--seq_len', type=int, default=512, help='sequence length') - parser.add_argument('--max_predictions_per_seq', - type=int, - default=80, - help='number of shards, e.g., 10, 50, or 100') - parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence') - parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id') - parser.add_argument('--backend', - type=str, - default='python', - help='backend of mask token, python, c++, numpy respectively') + parser.add_argument("--tokenizer_path", type=str, required=True, default=10, help="path of tokenizer") + parser.add_argument("--seq_len", type=int, default=512, help="sequence length") + parser.add_argument( + "--max_predictions_per_seq", type=int, default=80, help="number of shards, e.g., 10, 50, or 100" + ) + parser.add_argument("--input_path", type=str, required=True, help="input path of shard which has split sentence") + parser.add_argument("--output_path", type=str, required=True, help="output path of h5 contains token id") + parser.add_argument( + "--backend", type=str, default="python", help="backend of mask token, python, c++, numpy respectively" + ) parser.add_argument( - '--dupe_factor', + "--dupe_factor", type=int, default=1, - help='specifies how many times the preprocessor repeats to create the input from the same article/document') - parser.add_argument('--worker', type=int, default=32, help='number of process') - parser.add_argument('--server_num', type=int, default=10, help='number of servers') + help="specifies how many times the preprocessor repeats to create the input from the same article/document", + ) + parser.add_argument("--worker", type=int, default=32, help="number of process") + parser.add_argument("--server_num", type=int, default=10, help="number of servers") args = parser.parse_args() tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) - pretrain_data = PreTrainingDataset(tokenizer, - args.seq_len, - args.backend, - max_predictions_per_seq=args.max_predictions_per_seq) + pretrain_data = PreTrainingDataset( + tokenizer, args.seq_len, args.backend, max_predictions_per_seq=args.max_predictions_per_seq + ) data_len = len(os.listdir(args.input_path)) for i in range(data_len): - input_path = os.path.join(args.input_path, f'{i}.txt') + input_path = os.path.join(args.input_path, f"{i}.txt") if os.path.exists(input_path): start = time.time() - print(f'process {input_path}') - split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor, - args.seq_len, i) + print(f"process {input_path}") + split_numpy_chunk_pool( + input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor, args.seq_len, i + ) end_ = time.time() - print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)) - print(f'has cost {(end_ - start) / 60}') - print('-' * 100) - print('') + print("memory:%.4f GB" % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)) + print(f"has cost {(end_ - start) / 60}") + print("-" * 100) + print("") # if you have multiple server, you can use code below or modify code to openmpi diff --git a/examples/community/roberta/pretraining/arguments.py b/examples/community/roberta/pretraining/arguments.py index e0702ceb59b0..35b809d80947 100644 --- a/examples/community/roberta/pretraining/arguments.py +++ b/examples/community/roberta/pretraining/arguments.py @@ -1,8 +1,6 @@ -from numpy import require - import colossalai -__all__ = ['parse_args'] +__all__ = ["parse_args"] def parse_args(): @@ -11,7 +9,7 @@ def parse_args(): parser.add_argument( "--distplan", type=str, - default='CAI_Gemini', + default="CAI_Gemini", help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", ) parser.add_argument( @@ -23,65 +21,66 @@ def parse_args(): parser.add_argument( "--placement", type=str, - default='cpu', + default="cpu", help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", ) parser.add_argument( "--shardinit", - action='store_true', - help= - "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + action="store_true", + help="Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", ) - parser.add_argument('--lr', type=float, required=True, help='initial learning rate') - parser.add_argument('--epoch', type=int, required=True, help='number of epoch') - parser.add_argument('--data_path_prefix', type=str, required=True, help="location of the train data corpus") - parser.add_argument('--eval_data_path_prefix', - type=str, - required=True, - help='location of the evaluation data corpus') - parser.add_argument('--tokenizer_path', type=str, required=True, help='location of the tokenizer') - parser.add_argument('--max_seq_length', type=int, default=512, help='sequence length') - parser.add_argument('--refresh_bucket_size', - type=int, - default=1, - help="This param makes sure that a certain task is repeated for this time steps to \ - optimize on the back propagation speed with APEX's DistributedDataParallel") - parser.add_argument("--max_predictions_per_seq", - "--max_pred", - default=80, - type=int, - help="The maximum number of masked tokens in a sequence to be predicted.") + parser.add_argument("--lr", type=float, required=True, help="initial learning rate") + parser.add_argument("--epoch", type=int, required=True, help="number of epoch") + parser.add_argument("--data_path_prefix", type=str, required=True, help="location of the train data corpus") + parser.add_argument( + "--eval_data_path_prefix", type=str, required=True, help="location of the evaluation data corpus" + ) + parser.add_argument("--tokenizer_path", type=str, required=True, help="location of the tokenizer") + parser.add_argument("--max_seq_length", type=int, default=512, help="sequence length") + parser.add_argument( + "--refresh_bucket_size", + type=int, + default=1, + help="This param makes sure that a certain task is repeated for this time steps to \ + optimize on the back propagation speed with APEX's DistributedDataParallel", + ) + parser.add_argument( + "--max_predictions_per_seq", + "--max_pred", + default=80, + type=int, + help="The maximum number of masked tokens in a sequence to be predicted.", + ) parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="accumulation_steps") parser.add_argument("--train_micro_batch_size_per_gpu", default=2, type=int, required=True, help="train batch size") parser.add_argument("--eval_micro_batch_size_per_gpu", default=2, type=int, required=True, help="eval batch size") parser.add_argument("--num_workers", default=8, type=int, help="") - parser.add_argument("--async_worker", action='store_true', help="") + parser.add_argument("--async_worker", action="store_true", help="") parser.add_argument("--bert_config", required=True, type=str, help="location of config.json") - parser.add_argument("--wandb", action='store_true', help="use wandb to watch model") - parser.add_argument("--wandb_project_name", default='roberta', help="wandb project name") + parser.add_argument("--wandb", action="store_true", help="use wandb to watch model") + parser.add_argument("--wandb_project_name", default="roberta", help="wandb project name") parser.add_argument("--log_interval", default=100, type=int, help="report interval") parser.add_argument("--log_path", type=str, required=True, help="log file which records train step") parser.add_argument("--tensorboard_path", type=str, required=True, help="location of tensorboard file") - parser.add_argument("--colossal_config", - type=str, - required=True, - help="colossal config, which contains zero config and so on") - parser.add_argument("--ckpt_path", - type=str, - required=True, - help="location of saving checkpoint, which contains model and optimizer") - parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") - parser.add_argument('--vscode_debug', action='store_true', help="use vscode to debug") - parser.add_argument('--load_pretrain_model', default='', type=str, help="location of model's checkpoint") parser.add_argument( - '--load_optimizer_lr', - default='', + "--colossal_config", type=str, required=True, help="colossal config, which contains zero config and so on" + ) + parser.add_argument( + "--ckpt_path", type=str, required=True, help="location of saving checkpoint, which contains model and optimizer" + ) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--vscode_debug", action="store_true", help="use vscode to debug") + parser.add_argument("--load_pretrain_model", default="", type=str, help="location of model's checkpoint") + parser.add_argument( + "--load_optimizer_lr", + default="", type=str, - help="location of checkpoint, which contains optimizer, learning rate, epoch, shard and global_step") - parser.add_argument('--resume_train', action='store_true', help="whether resume training from a early checkpoint") - parser.add_argument('--mlm', default='bert', type=str, help="model type, bert or deberta") - parser.add_argument('--checkpoint_activations', action='store_true', help="whether to use gradient checkpointing") + help="location of checkpoint, which contains optimizer, learning rate, epoch, shard and global_step", + ) + parser.add_argument("--resume_train", action="store_true", help="whether resume training from a early checkpoint") + parser.add_argument("--mlm", default="bert", type=str, help="model type, bert or deberta") + parser.add_argument("--checkpoint_activations", action="store_true", help="whether to use gradient checkpointing") args = parser.parse_args() return args diff --git a/examples/community/roberta/pretraining/bert_dataset_provider.py b/examples/community/roberta/pretraining/bert_dataset_provider.py index eaf165ed18f4..1d8cf2a910e9 100644 --- a/examples/community/roberta/pretraining/bert_dataset_provider.py +++ b/examples/community/roberta/pretraining/bert_dataset_provider.py @@ -1,5 +1,4 @@ class BertDatasetProviderInterface: - def get_shard(self, index, shuffle=True): raise NotImplementedError diff --git a/examples/community/roberta/pretraining/evaluation.py b/examples/community/roberta/pretraining/evaluation.py index 009242cd1cf5..e1bce48023c3 100644 --- a/examples/community/roberta/pretraining/evaluation.py +++ b/examples/community/roberta/pretraining/evaluation.py @@ -19,23 +19,27 @@ def evaluate(model, args, logger, global_step, criterion): world_size = torch.distributed.get_world_size() with torch.no_grad(): - for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))): - - timers('eval_shard_time').start() + timers("eval_shard_time").start() dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard) # evaluate_dataset_provider.prefetch_shard(shard + 1) if torch.distributed.get_rank() == 0: - iterator_data = tqdm(enumerate(dataset_iterator), - total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), - colour='MAGENTA', - smoothing=1) + iterator_data = tqdm( + enumerate(dataset_iterator), + total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), + colour="MAGENTA", + smoothing=1, + ) else: iterator_data = enumerate(dataset_iterator) - for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): - + for ( + step, + batch_data, + ) in ( + iterator_data + ): # tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): # batch_data = pretrain_dataset_provider.get_batch(batch_index) eval_step += 1 input_ids = batch_data[0].cuda() @@ -46,7 +50,7 @@ def evaluate(model, args, logger, global_step, criterion): output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - loss = criterion(output.logits, mlm_label) #prediction_scores + loss = criterion(output.logits, mlm_label) # prediction_scores evaluate_dataset_provider.prefetch_batch() eval_loss += loss.float().item() @@ -58,18 +62,18 @@ def evaluate(model, args, logger, global_step, criterion): if args.wandb and torch.distributed.get_rank() == 0: tensorboard_log = get_tensorboard_writer() - tensorboard_log.log_eval({ - 'loss': cur_loss, - 'ppl': ppl, - 'mins_batch': elapsed_time_per_iteration - }, global_step) + tensorboard_log.log_eval( + {"loss": cur_loss, "ppl": ppl, "mins_batch": elapsed_time_per_iteration}, global_step + ) - eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ - f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}' + eval_log_str = ( + f"evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes " + + f"| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}" + ) logger.info(eval_log_str) - logger.info('-' * 100) - logger.info('') + logger.info("-" * 100) + logger.info("") evaluate_dataset_provider.release_shard() model.train() diff --git a/examples/community/roberta/pretraining/loss.py b/examples/community/roberta/pretraining/loss.py index 989c2bd5c450..636246292809 100644 --- a/examples/community/roberta/pretraining/loss.py +++ b/examples/community/roberta/pretraining/loss.py @@ -1,10 +1,9 @@ import torch -__all__ = ['LossForPretraining'] +__all__ = ["LossForPretraining"] class LossForPretraining(torch.nn.Module): - def __init__(self, vocab_size): super(LossForPretraining, self).__init__() self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1) @@ -13,5 +12,5 @@ def __init__(self, vocab_size): def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None): masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1)) # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1)) - total_loss = masked_lm_loss #+ next_sentence_loss + total_loss = masked_lm_loss # + next_sentence_loss return total_loss diff --git a/examples/community/roberta/pretraining/model/bert.py b/examples/community/roberta/pretraining/model/bert.py index abdf925d0540..31e3d7075a0c 100644 --- a/examples/community/roberta/pretraining/model/bert.py +++ b/examples/community/roberta/pretraining/model/bert.py @@ -59,7 +59,8 @@ # TokenClassification docstring _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" _TOKEN_CLASS_EXPECTED_OUTPUT = ( - "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] ") + "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] " +) _TOKEN_CLASS_EXPECTED_LOSS = 0.01 # QuestionAnswering docstring @@ -109,8 +110,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): import numpy as np import tensorflow as tf except ImportError: - logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions.") + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) raise tf_path = os.path.abspath(tf_checkpoint_path) logger.info(f"Converting TensorFlow checkpoint from {tf_path}") @@ -128,8 +131,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): name = name.split("/") # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model - if any(n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name): + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): logger.info(f"Skipping {'/'.join(name)}") continue pointer = model @@ -209,7 +214,7 @@ def forward( seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length] + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves @@ -236,12 +241,13 @@ def forward( class BertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})") + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) @@ -320,7 +326,7 @@ def forward( position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) @@ -360,7 +366,6 @@ def forward( class BertSelfOutput(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -375,7 +380,6 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): super().__init__() self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) @@ -385,8 +389,9 @@ def __init__(self, config, position_embedding_type=None): def prune_heads(self, heads): if len(heads) == 0: return - heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads, - self.self.attention_head_size, self.pruned_heads) + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) @@ -419,12 +424,11 @@ def forward( output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs class BertIntermediate(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -440,7 +444,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertOutput(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -455,7 +458,6 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertLayer(nn.Module): - def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -496,14 +498,15 @@ def forward( outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`") + " by setting `config.add_cross_attention=True`" + ) # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None @@ -517,14 +520,15 @@ def forward( output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights # add cross-attn cache to positions 3,4 of present_key_value tuple cross_attn_present_key_value = cross_attention_outputs[-1] present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, - self.seq_len_dim, attention_output) + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) outputs = (layer_output,) + outputs # if decoder, return the attn key/values as the last output @@ -540,7 +544,6 @@ def feed_forward_chunk(self, attention_output): class BertEncoder(nn.Module): - def __init__(self, config): super().__init__() self.config = config @@ -573,14 +576,13 @@ def forward( past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - if use_cache: logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) @@ -617,13 +619,17 @@ def custom_forward(*inputs): all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] if v is not None) + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, @@ -634,7 +640,6 @@ def custom_forward(*inputs): class BertPooler(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -650,7 +655,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertPredictionHeadTransform(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -668,7 +672,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertLMPredictionHead(nn.Module): - def __init__(self, config): super().__init__() self.transform = BertPredictionHeadTransform(config) @@ -689,7 +692,6 @@ def forward(self, hidden_states): class BertOnlyMLMHead(nn.Module): - def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) @@ -700,7 +702,6 @@ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: class BertOnlyNSPHead(nn.Module): - def __init__(self, config): super().__init__() self.seq_relationship = nn.Linear(config.hidden_size, 2) @@ -711,7 +712,6 @@ def forward(self, pooled_output): class BertPreTrainingHeads(nn.Module): - def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) @@ -943,8 +943,9 @@ def forward( `past_key_values`). """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.config.is_decoder: @@ -1043,7 +1044,6 @@ def forward( BERT_START_DOCSTRING, ) class BertForPreTraining(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1144,10 +1144,10 @@ def forward( ) -@add_start_docstrings("""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", - BERT_START_DOCSTRING) +@add_start_docstrings( + """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING +) class BertLMHeadModel(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1282,7 +1282,6 @@ def _reorder_cache(self, past, beam_idx): @add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING) class BertForMaskedLM(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] @@ -1290,8 +1289,10 @@ def __init__(self, config): super().__init__(config) if config.is_decoder: - logger.warning("If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention.") + logger.warning( + "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) @@ -1357,7 +1358,7 @@ def forward( masked_lm_loss = None if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -1380,10 +1381,9 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ raise ValueError("The PAD token should be defined for generation") attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) - dummy_token = torch.full((effective_batch_size, 1), - self.config.pad_token_id, - dtype=torch.long, - device=input_ids.device) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) input_ids = torch.cat([input_ids, dummy_token], dim=1) return {"input_ids": input_ids, "attention_mask": attention_mask} @@ -1394,7 +1394,6 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ BERT_START_DOCSTRING, ) class BertForNextSentencePrediction(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1500,15 +1499,15 @@ def forward( BERT_START_DOCSTRING, ) class BertForSequenceClassification(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.config = config self.bert = BertModel(config) - classifier_dropout = (config.classifier_dropout - if config.classifier_dropout is not None else config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) @@ -1604,13 +1603,13 @@ def forward( BERT_START_DOCSTRING, ) class BertForMultipleChoice(BertPreTrainedModel): - def __init__(self, config): super().__init__(config) self.bert = BertModel(config) - classifier_dropout = (config.classifier_dropout - if config.classifier_dropout is not None else config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, 1) @@ -1650,8 +1649,11 @@ def forward( attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None else None) + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) outputs = self.bert( input_ids, @@ -1696,7 +1698,6 @@ def forward( BERT_START_DOCSTRING, ) class BertForTokenClassification(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): @@ -1704,8 +1705,9 @@ def __init__(self, config): self.num_labels = config.num_labels self.bert = BertModel(config, add_pooling_layer=False) - classifier_dropout = (config.classifier_dropout - if config.classifier_dropout is not None else config.hidden_dropout_prob) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) @@ -1782,7 +1784,6 @@ def forward( BERT_START_DOCSTRING, ) class BertForQuestionAnswering(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): diff --git a/examples/community/roberta/pretraining/model/deberta_v2.py b/examples/community/roberta/pretraining/model/deberta_v2.py index 5fc284911e38..c7457942e164 100644 --- a/examples/community/roberta/pretraining/model/deberta_v2.py +++ b/examples/community/roberta/pretraining/model/deberta_v2.py @@ -23,7 +23,6 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss -from transformers import FillMaskPipeline, T5ForConditionalGeneration, T5Tokenizer from transformers.activations import ACT2FN from transformers.modeling_outputs import ( BaseModelOutput, @@ -59,7 +58,6 @@ # Copied from transformers.models.deberta.modeling_deberta.ContextPooler class ContextPooler(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) @@ -138,15 +136,15 @@ def symbolic(g, self, mask, dim): g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx["Byte"], ) - output = masked_fill(g, self, r_mask, - g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))) + output = masked_fill( + g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) + ) output = softmax(g, output, dim) return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) # Copied from transformers.models.deberta.modeling_deberta.DropoutContext class DropoutContext(object): - def __init__(self): self.dropout = 0 self.mask = None @@ -249,7 +247,6 @@ def get_context(self): # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm class DebertaV2SelfOutput(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -265,7 +262,6 @@ def forward(self, hidden_states, input_tensor): # Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 class DebertaV2Attention(nn.Module): - def __init__(self, config): super().__init__() self.self = DisentangledSelfAttention(config) @@ -303,7 +299,6 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 class DebertaV2Intermediate(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -320,7 +315,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm class DebertaV2Output(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -337,7 +331,6 @@ def forward(self, hidden_states, input_tensor): # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 class DebertaV2Layer(nn.Module): - def __init__(self, config): super().__init__() self.attention = DebertaV2Attention(config) @@ -372,17 +365,14 @@ def forward( class ConvLayer(nn.Module): - def __init__(self, config): super().__init__() kernel_size = getattr(config, "conv_kernel_size", 3) groups = getattr(config, "conv_groups", 1) self.conv_act = getattr(config, "conv_act", "tanh") - self.conv = nn.Conv1d(config.hidden_size, - config.hidden_size, - kernel_size, - padding=(kernel_size - 1) // 2, - groups=groups) + self.conv = nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) self.dropout = StableDropout(config.hidden_dropout_prob) self.config = config @@ -465,10 +455,9 @@ def get_attention_mask(self, attention_mask): def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): if self.relative_attention and relative_pos is None: q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) - relative_pos = build_relative_position(q, - hidden_states.size(-2), - bucket_size=self.position_buckets, - max_position=self.max_relative_positions) + relative_pos = build_relative_position( + q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) return relative_pos def forward( @@ -498,14 +487,12 @@ def forward( rel_embeddings = self.get_rel_embedding() output_states = next_kv for i, layer_module in enumerate(self.layer): - if output_hidden_states: all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): return module(*inputs, output_attentions) @@ -550,9 +537,9 @@ def custom_forward(*inputs): if not return_dict: return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) - return BaseModelOutput(last_hidden_state=output_states, - hidden_states=all_hidden_states, - attentions=all_attentions) + return BaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + ) def make_log_bucket_position(relative_pos, bucket_size, max_position): @@ -625,8 +612,10 @@ class DisentangledSelfAttention(nn.Module): def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0: - raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})") + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) self.num_attention_heads = config.num_attention_heads _attention_head_size = config.hidden_size // config.num_attention_heads self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) @@ -719,22 +708,28 @@ def forward( attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) - rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, - scale_factor) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) if rel_att is not None: attention_scores = attention_scores + rel_att attention_scores = attention_scores - attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2), - attention_scores.size(-1)) + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) # bsz x height x length x dimension attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) attention_probs = self.dropout(attention_probs) - context_layer = torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), - value_layer) - context_layer = (context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), - context_layer.size(-1)).permute(0, 2, 1, 3).contiguous()) + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(new_context_layer_shape) if output_attentions: @@ -745,10 +740,9 @@ def forward( def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): if relative_pos is None: q = query_layer.size(-2) - relative_pos = build_relative_position(q, - key_layer.size(-2), - bucket_size=self.position_buckets, - max_position=self.max_relative_positions) + relative_pos = build_relative_position( + q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) if relative_pos.dim() == 2: relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) elif relative_pos.dim() == 3: @@ -766,22 +760,25 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ # rel_embeddings = rel_embeddings.unsqueeze(0) # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) if self.share_att_key: - pos_query_layer = self.transpose_for_scores(self.query_proj(rel_embeddings), - self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1) + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) else: if "c2p" in self.pos_att_type: - pos_key_layer = self.transpose_for_scores(self.pos_key_proj(rel_embeddings), - self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, - 1) # .split(self.all_head_size, dim=-1) + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) if "p2c" in self.pos_att_type: - pos_query_layer = self.transpose_for_scores(self.pos_query_proj(rel_embeddings), - self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, - 1) # .split(self.all_head_size, dim=-1) + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) score = 0 # content->position @@ -792,9 +789,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ c2p_att = torch.gather( c2p_att, dim=-1, - index=c2p_pos.squeeze(0).expand([query_layer.size(0), - query_layer.size(1), - relative_pos.size(-1)]), + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), ) score += c2p_att / scale @@ -817,9 +812,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ p2c_att = torch.gather( p2c_att, dim=-1, - index=p2c_pos.squeeze(0).expand([query_layer.size(0), - key_layer.size(-2), - key_layer.size(-2)]), + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), ).transpose(-1, -2) score += p2c_att / scale @@ -999,7 +992,6 @@ def _set_gradient_checkpointing(self, module, value=False): ) # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 class DebertaV2Model(DebertaV2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1042,8 +1034,9 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not None: @@ -1100,7 +1093,7 @@ def forward( sequence_output = encoded_layers[-1] if not return_dict: - return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2):] + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] return BaseModelOutput( last_hidden_state=sequence_output, @@ -1174,7 +1167,7 @@ def forward( masked_lm_loss = None if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -1191,7 +1184,6 @@ def forward( # copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta class DebertaV2PredictionHeadTransform(nn.Module): - def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -1210,7 +1202,6 @@ def forward(self, hidden_states): # copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta class DebertaV2LMPredictionHead(nn.Module): - def __init__(self, config): super().__init__() self.transform = DebertaV2PredictionHeadTransform(config) @@ -1232,7 +1223,6 @@ def forward(self, hidden_states): # copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta class DebertaV2OnlyMLMHead(nn.Module): - def __init__(self, config): super().__init__() self.predictions = DebertaV2LMPredictionHead(config) @@ -1251,7 +1241,6 @@ def forward(self, sequence_output): ) # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1331,8 +1320,9 @@ def forward( label_index = (labels >= 0).nonzero() labels = labels.long() if label_index.size(0) > 0: - labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), - logits.size(1))) + labeled_logits = torch.gather( + logits, 0, label_index.expand(label_index.size(0), logits.size(1)) + ) labels = torch.gather(labels, 0, label_index.view(-1)) loss_fct = CrossEntropyLoss() loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) @@ -1357,10 +1347,9 @@ def forward( output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutput(loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions) + return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) @add_start_docstrings( @@ -1435,10 +1424,9 @@ def forward( output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput(loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions) + return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) @add_start_docstrings( @@ -1550,7 +1538,6 @@ def forward( DEBERTA_START_DOCSTRING, ) class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -1606,8 +1593,11 @@ def forward( flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - flat_inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None else None) + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) outputs = self.deberta( flat_input_ids, diff --git a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py index 72c7bd852a40..09677a6195cb 100644 --- a/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py +++ b/examples/community/roberta/pretraining/nvidia_bert_dataset_provider.py @@ -1,5 +1,3 @@ -import json -import logging import os import random import time @@ -12,14 +10,10 @@ from bert_dataset_provider import BertDatasetProviderInterface from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import RandomSampler - -import colossalai.utils as utils # Workaround because python functions are not picklable class WorkerInitObj(object): - def __init__(self, seed): self.seed = seed @@ -28,44 +22,46 @@ def __call__(self, id): random.seed(self.seed + id) -def create_pretraining_dataset(input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init, - data_sampler): +def create_pretraining_dataset( + input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init, data_sampler +): train_data = pretraining_dataset(input_file=input_file, max_predictions_per_seq=max_predictions_per_seq) - train_dataloader = DataLoader(train_data, - sampler=data_sampler(train_data), - batch_size=train_batch_size, - num_workers=num_workers, - worker_init_fn=worker_init, - pin_memory=True) + train_dataloader = DataLoader( + train_data, + sampler=data_sampler(train_data), + batch_size=train_batch_size, + num_workers=num_workers, + worker_init_fn=worker_init, + pin_memory=True, + ) return train_dataloader, len(train_data) class pretraining_dataset(Dataset): - def __init__(self, input_file, max_predictions_per_seq): self.input_file = input_file self.max_predictions_per_seq = max_predictions_per_seq f = h5py.File(input_file, "r") - keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions'] + keys = ["input_ids", "input_mask", "segment_ids", "masked_lm_positions"] self.inputs = [np.asarray(f[key][:]) for key in keys] f.close() def __len__(self): - 'Denotes the total number of samples' + "Denotes the total number of samples" return len(self.inputs[0]) def __getitem__(self, index): - [input_ids, input_mask, segment_ids, masked_lm_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy( - np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs) + torch.from_numpy(input[index].astype(np.int64)) + if indice < 5 + else torch.from_numpy(np.asarray(input[index].astype(np.int64))) + for indice, input in enumerate(self.inputs) ] return [input_ids, input_mask, segment_ids, masked_lm_labels] class NvidiaBertDatasetProvider(BertDatasetProviderInterface): - def __init__(self, args, evaluate=False): self.num_workers = args.num_workers self.max_seq_length = args.max_seq_length @@ -86,13 +82,13 @@ def __init__(self, args, evaluate=False): self.dataset_files = [ os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) - if os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f + if os.path.isfile(os.path.join(args.data_path_prefix, f)) and "h5" in f ] else: self.dataset_files = [ os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) - if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f + if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and "h5" in f ] self.dataset_files.sort() @@ -120,7 +116,8 @@ def get_shard(self, index): num_workers=self.num_workers, train_batch_size=self.train_micro_batch_size_per_gpu, worker_init=self.worker_init, - data_sampler=self.data_sampler) + data_sampler=self.data_sampler, + ) else: self.train_dataloader, sample_count = self.dataset_future.result(timeout=None) @@ -136,9 +133,15 @@ def release_shard(self): def prefetch_shard(self, index): self.data_file = self._get_shard_file(index) - self.dataset_future = self.pool.submit(create_pretraining_dataset, self.data_file, self.max_predictions_per_seq, - self.num_workers, self.train_micro_batch_size_per_gpu, self.worker_init, - self.data_sampler) + self.dataset_future = self.pool.submit( + create_pretraining_dataset, + self.data_file, + self.max_predictions_per_seq, + self.num_workers, + self.train_micro_batch_size_per_gpu, + self.worker_init, + self.data_sampler, + ) def get_batch(self, batch_iter): return batch_iter diff --git a/examples/community/roberta/pretraining/pretrain_utils.py b/examples/community/roberta/pretraining/pretrain_utils.py index e6a393a57dda..1370b413b712 100644 --- a/examples/community/roberta/pretraining/pretrain_utils.py +++ b/examples/community/roberta/pretraining/pretrain_utils.py @@ -1,24 +1,12 @@ -import logging import os import sys import torch import transformers -from torch.optim import AdamW -from transformers import ( - AutoModelForMaskedLM, - AutoTokenizer, - BertForPreTraining, - GPT2Config, - GPT2LMHeadModel, - RobertaConfig, - RobertaForMaskedLM, - get_linear_schedule_with_warmup, -) +from transformers import get_linear_schedule_with_warmup from colossalai.legacy.core import global_context as gpc -from colossalai.nn.lr_scheduler import LinearWarmupLR -from colossalai.nn.optimizer import FusedAdam, HybridAdam +from colossalai.nn.optimizer import HybridAdam sys.path.append(os.getcwd()) from collections import OrderedDict @@ -27,7 +15,7 @@ from model.bert import BertForMaskedLM from model.deberta_v2 import DebertaV2ForMaskedLM -__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining'] +__all__ = ["get_model", "get_optimizer", "get_lr_scheduler", "get_dataloader_for_pretraining"] def get_new_state_dict(state_dict, start_index=13): @@ -39,7 +27,6 @@ def get_new_state_dict(state_dict, start_index=13): class LMModel(nn.Module): - def __init__(self, model, config, args): super().__init__() @@ -55,11 +42,10 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None): def get_model(args, logger): - - if args.mlm == 'bert': + if args.mlm == "bert": config = transformers.BertConfig.from_json_file(args.bert_config) model = BertForMaskedLM(config) - elif args.mlm == 'deberta_v2': + elif args.mlm == "deberta_v2": config = transformers.DebertaV2Config.from_json_file(args.bert_config) model = DebertaV2ForMaskedLM(config) else: @@ -68,11 +54,13 @@ def get_model(args, logger): if len(args.load_pretrain_model) > 0: assert os.path.exists(args.load_pretrain_model) # load_checkpoint(args.load_pretrain_model, model, strict=False) - m_state_dict = torch.load(args.load_pretrain_model, - map_location=torch.device(f"cuda:{torch.cuda.current_device()}")) + m_state_dict = torch.load( + args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}") + ) # new_state_dict = get_new_state_dict(m_state_dict) - model.load_state_dict(m_state_dict, - strict=True) # must insure that every process have identical parameters !!!!!!! + model.load_state_dict( + m_state_dict, strict=True + ) # must insure that every process have identical parameters !!!!!!! logger.info("load model success") numel = sum([p.numel() for p in model.parameters()]) @@ -85,40 +73,36 @@ def get_model(args, logger): def get_optimizer(model, lr): param_optimizer = list(model.named_parameters()) - no_decay = ['bias', 'gamma', 'beta', 'LayerNorm'] + no_decay = ["bias", "gamma", "beta", "LayerNorm"] # configure the weight decay for bert models - optimizer_grouped_parameters = [{ - 'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], - 'weight_decay': 0.1 - }, { - 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], - 'weight_decay': 0.0 - }] + optimizer_grouped_parameters = [ + {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.1}, + {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, + ] optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95]) return optimizer def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1): # warmup_steps = int(total_steps * warmup_ratio) - lr_scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps=warmup_steps, - num_training_steps=total_steps, - last_epoch=last_epoch) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch + ) # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps) return lr_scheduler def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step): - model_path = path + '_pytorch_model.bin' - optimizer_lr_path = path + '.op_lrs' + model_path = path + "_pytorch_model.bin" + optimizer_lr_path = path + ".op_lrs" checkpoint = {} - checkpoint['optimizer'] = optimizer.state_dict() - checkpoint['lr_scheduler'] = lr_scheduler.state_dict() - checkpoint['epoch'] = epoch - checkpoint['shard'] = shard - checkpoint['global_step'] = global_step - model_state = model.state_dict() #each process must run model.state_dict() + checkpoint["optimizer"] = optimizer.state_dict() + checkpoint["lr_scheduler"] = lr_scheduler.state_dict() + checkpoint["epoch"] = epoch + checkpoint["shard"] = shard + checkpoint["global_step"] = global_step + model_state = model.state_dict() # each process must run model.state_dict() if gpc.get_global_rank() == 0: torch.save(checkpoint, optimizer_lr_path) torch.save(model_state, model_path) diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py index fa6457cab328..5396de6935cb 100644 --- a/examples/community/roberta/pretraining/run_pretraining.py +++ b/examples/community/roberta/pretraining/run_pretraining.py @@ -17,16 +17,13 @@ import colossalai from colossalai.context import ParallelMode -from colossalai.legacy.core import global_context as gpc -from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper -from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper +from colossalai.tensor import ProcessGroup, ShardSpec from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import GeminiOptimizer def main(): - args = parse_args() launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) @@ -37,20 +34,17 @@ def main(): logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) if args.vscode_debug: - colossalai.launch(config={}, - rank=args.rank, - world_size=args.world_size, - host=args.host, - port=args.port, - backend=args.backend) + colossalai.launch( + config={}, rank=args.rank, world_size=args.world_size, host=args.host, port=args.port, backend=args.backend + ) args.local_rank = -1 args.log_interval = 1 else: - colossalai.launch_from_torch(config={}) # args.colossal_config + colossalai.launch_from_torch(config={}) # args.colossal_config args.local_rank = int(os.environ["LOCAL_RANK"]) logger.info( - f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + - f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}' + f"launch_from_torch, world size: {torch.distributed.get_world_size()} | " + + f"ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}" ) log_args(logger, args) @@ -59,7 +53,7 @@ def main(): set_global_variables(launch_time, args.tensorboard_path) world_size = torch.distributed.get_world_size() - init_dev = get_current_device() + get_current_device() # build model, optimizer and criterion if args.distplan.startswith("CAI"): @@ -72,10 +66,9 @@ def main(): raise RuntimeError("You can only use shardinit with CAI_Gemini") # build GPT model - with ColoInitContext(device=get_current_device(), - dtype=torch.half, - default_dist_spec=default_dist_spec, - default_pg=shard_pg): + with ColoInitContext( + device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg + ): config, model, numel = get_model(args, logger) # assign running configurations @@ -83,13 +76,15 @@ def main(): if args.distplan.startswith("CAI_ZeRO"): optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) elif args.distplan == "CAI_Gemini": - gemini_config = dict(strict_ddp_mode=args.tp_degree == 1, - device=get_current_device(), - placement_policy=args.placement, - pin_memory=True, - hidden_dim=model.config.hidden_size, - search_range_m=128) - optim_config = dict(gpu_margin_mem_ratio=0.) + gemini_config = dict( + strict_ddp_mode=args.tp_degree == 1, + device=get_current_device(), + placement_policy=args.placement, + pin_memory=True, + hidden_dim=model.config.hidden_size, + search_range_m=128, + ) + optim_config = dict(gpu_margin_mem_ratio=0.0) else: raise RuntimeError @@ -109,7 +104,7 @@ def main(): model = zero_model_wrapper(model, zero_stage, gemini_config) optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config) - logger.info(get_mem_info(prefix='After init optim, ')) + logger.info(get_mem_info(prefix="After init optim, ")) else: config, model, numel = get_model(args, logger) @@ -118,13 +113,19 @@ def main(): if torch.distributed.get_rank() == 0: os.mkdir(os.path.join(args.ckpt_path, launch_time)) - logger.info(f'Model numel: {numel}') + logger.info(f"Model numel: {numel}") get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) # 144003367 is is the length of the entire dataset # len(dataloader) - steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size + steps_per_epoch = ( + 144003367 + // world_size + // args.train_micro_batch_size_per_gpu + // args.gradient_accumulation_steps + // args.refresh_bucket_size + ) total_steps = steps_per_epoch * args.epoch lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) @@ -134,25 +135,25 @@ def main(): global_step = 0 if args.resume_train: assert os.path.exists(args.load_optimizer_lr) - o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu') - o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 - optimizer.load_state_dict(o_l_state_dict['optimizer']) + o_l_state_dict = torch.load(args.load_optimizer_lr, map_location="cpu") + o_l_state_dict["lr_scheduler"]["last_epoch"] = o_l_state_dict["lr_scheduler"]["last_epoch"] - 1 + optimizer.load_state_dict(o_l_state_dict["optimizer"]) # o_l_state_dict['lr_scheduler']['last_epoch'] - lr_scheduler = get_lr_scheduler(optimizer, - total_steps=total_steps, - last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) + lr_scheduler = get_lr_scheduler( + optimizer, total_steps=total_steps, last_epoch=o_l_state_dict["lr_scheduler"]["last_epoch"] + ) for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") # if you want delete the above three code, must move the model to gpu. Because in optimizer.step() - lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) + lr_scheduler.load_state_dict(o_l_state_dict["lr_scheduler"]) - start_epoch = o_l_state_dict['epoch'] - start_shard = o_l_state_dict['shard'] + 1 + start_epoch = o_l_state_dict["epoch"] + start_shard = o_l_state_dict["shard"] + 1 # global_step = o_l_state_dict['global_step'] + 1 logger.info( - f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}' + f"resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}" ) criterion = LossForPretraining(config.vocab_size) @@ -160,34 +161,32 @@ def main(): # build dataloader pretrain_dataset_provider = NvidiaBertDatasetProvider(args) - logger.info(get_mem_info(prefix='After init model, ')) + logger.info(get_mem_info(prefix="After init model, ")) - best_loss = None eval_loss = 0 train_loss = 0 timers = get_timers() - timers('interval_time').start() - timers('epoch_time').start() - timers('shard_time').start() + timers("interval_time").start() + timers("epoch_time").start() + timers("shard_time").start() for epoch in range(start_epoch, args.epoch): - for shard in range(start_shard, len(os.listdir(args.data_path_prefix))): - dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload if torch.distributed.get_rank() == 0: - iterator_data = tqdm(enumerate(dataset_iterator), - total=(total_length // args.train_micro_batch_size_per_gpu // world_size), - colour='cyan', - smoothing=1) + iterator_data = tqdm( + enumerate(dataset_iterator), + total=(total_length // args.train_micro_batch_size_per_gpu // world_size), + colour="cyan", + smoothing=1, + ) else: iterator_data = enumerate(dataset_iterator) model.train() for step, batch_data in iterator_data: - # batch_data = pretrain_dataset_provider.get_batch(batch_index) input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") attention_mask = batch_data[1].cuda(f"cuda:{torch.cuda.current_device()}") @@ -209,56 +208,70 @@ def main(): global_step += 1 - if global_step % args.log_interval == 0 and global_step != 0 \ - and torch.distributed.get_rank() == 0: - elapsed_time = timers('interval_time').elapsed(reset=False) + if global_step % args.log_interval == 0 and global_step != 0 and torch.distributed.get_rank() == 0: + elapsed_time = timers("interval_time").elapsed(reset=False) elapsed_time_per_iteration = elapsed_time / global_step samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( - numel, args, config, elapsed_time, global_step, world_size) + numel, args, config, elapsed_time, global_step, world_size + ) cur_loss = train_loss / args.log_interval current_lr = lr_scheduler.get_last_lr()[0] - log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ - f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}' + log_str = ( + f"| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes " + + f"| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}" + ) logger.info(log_str, print_=False) if args.wandb: tensorboard_log = get_tensorboard_writer() tensorboard_log.log_train( { - 'lr': current_lr, - 'loss': cur_loss, - 'ppl': math.exp(cur_loss), - 'mins_batch': elapsed_time_per_iteration - }, global_step) + "lr": current_lr, + "loss": cur_loss, + "ppl": math.exp(cur_loss), + "mins_batch": elapsed_time_per_iteration, + }, + global_step, + ) train_loss = 0 logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins') - logger.info('*' * 100) + logger.info("*" * 100) eval_loss += evaluate(model, args, logger, global_step, criterion) - save_ckpt(model, optimizer, lr_scheduler, - os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, - shard, global_step) + save_ckpt( + model, + optimizer, + lr_scheduler, + os.path.join(args.ckpt_path, launch_time, f"epoch-{epoch}_shard-{shard}_" + launch_time), + epoch, + shard, + global_step, + ) eval_loss /= len(os.listdir(args.data_path_prefix)) logger.info( f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' - + f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') - logger.info('-' * 100) + + f"eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}" + ) + logger.info("-" * 100) if args.wandb and torch.distributed.get_rank() == 0: tensorboard_log = get_tensorboard_writer() - tensorboard_log.log_eval({ - 'all_eval_shard_loss': eval_loss, - }, epoch) + tensorboard_log.log_eval( + { + "all_eval_shard_loss": eval_loss, + }, + epoch, + ) start_shard = 0 eval_loss = 0 pretrain_dataset_provider.release_shard() - logger.info('Congratulation, training has finished!!!') + logger.info("Congratulation, training has finished!!!") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/community/roberta/pretraining/utils/WandbLog.py b/examples/community/roberta/pretraining/utils/WandbLog.py index b68ba8387dcd..d73393c348d8 100644 --- a/examples/community/roberta/pretraining/utils/WandbLog.py +++ b/examples/community/roberta/pretraining/utils/WandbLog.py @@ -6,7 +6,6 @@ class WandbLog: - @classmethod def init_wandb(cls, project, notes=None, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): wandb.init(project=project, notes=notes, name=name, config=config) @@ -23,7 +22,6 @@ def log(cls, result, model=None, gradient=None): class TensorboardLog: - def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), config=None): if not os.path.exists(location): os.mkdir(location) @@ -31,12 +29,12 @@ def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localt def log_train(self, result, step): for k, v in result.items(): - self.writer.add_scalar(f'{k}/train', v, step) + self.writer.add_scalar(f"{k}/train", v, step) def log_eval(self, result, step): for k, v in result.items(): - self.writer.add_scalar(f'{k}/eval', v, step) + self.writer.add_scalar(f"{k}/eval", v, step) def log_zeroshot(self, result, step): for k, v in result.items(): - self.writer.add_scalar(f'{k}_acc/eval', v, step) + self.writer.add_scalar(f"{k}_acc/eval", v, step) diff --git a/examples/community/roberta/pretraining/utils/exp_util.py b/examples/community/roberta/pretraining/utils/exp_util.py index 1fcaa428b277..e95b6efda4c8 100644 --- a/examples/community/roberta/pretraining/utils/exp_util.py +++ b/examples/community/roberta/pretraining/utils/exp_util.py @@ -12,8 +12,8 @@ def logging(s, log_path, print_=True, log_=True): if print_: print(s) if log_: - with open(log_path, 'a+') as f_log: - f_log.write(s + '\n') + with open(log_path, "a+") as f_log: + f_log.write(s + "\n") def get_logger(log_path, **kwargs): @@ -22,22 +22,22 @@ def get_logger(log_path, **kwargs): def create_exp_dir(dir_path, scripts_to_save=None, debug=False): if debug: - print('Debug Mode : no experiment dir created') + print("Debug Mode : no experiment dir created") return functools.partial(logging, log_path=None, log_=False) if not os.path.exists(dir_path): os.makedirs(dir_path) - print('Experiment dir : {}'.format(dir_path)) + print("Experiment dir : {}".format(dir_path)) if scripts_to_save is not None: - script_path = os.path.join(dir_path, 'scripts') + script_path = os.path.join(dir_path, "scripts") if not os.path.exists(script_path): os.makedirs(script_path) for script in scripts_to_save: - dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) + dst_file = os.path.join(dir_path, "scripts", os.path.basename(script)) shutil.copyfile(script, dst_file) - return get_logger(log_path=os.path.join(dir_path, 'log.txt')) + return get_logger(log_path=os.path.join(dir_path, "log.txt")) def get_cpu_mem(): @@ -48,8 +48,8 @@ def get_gpu_mem(): return torch.cuda.memory_allocated() / 1024**2 -def get_mem_info(prefix=''): - return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' +def get_mem_info(prefix=""): + return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" def get_tflops(model_numel, batch_size, seq_len, step_time): @@ -59,11 +59,12 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): def get_parameters_in_billions(model, world_size=1): gpus_per_model = world_size - approx_parameters_in_billions = sum([ - sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement() - for p in model_module.parameters()]) - for model_module in model - ]) + approx_parameters_in_billions = sum( + [ + sum([p.ds_numel if hasattr(p, "ds_id") else p.nelement() for p in model_module.parameters()]) + for model_module in model + ] + ) return approx_parameters_in_billions * gpus_per_model / (1e9) @@ -71,13 +72,13 @@ def get_parameters_in_billions(model, world_size=1): def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1): gpus_per_model = 1 batch_size = args.train_micro_batch_size_per_gpu - samples_per_model = batch_size * args.max_seq_length - model_replica_count = world_size / gpus_per_model + batch_size * args.max_seq_length + world_size / gpus_per_model approx_parameters_in_billions = numel elapsed_time_per_iter = iteration_time / total_iterations samples_per_second = batch_size / elapsed_time_per_iter - #flops calculator + # flops calculator hidden_size = config.hidden_size num_layers = config.num_hidden_layers vocab_size = config.vocab_size @@ -87,9 +88,9 @@ def throughput_calculator(numel, args, config, iteration_time, total_iterations, # The factor of 4 is when used with activation check-pointing, # otherwise it will be 3. checkpoint_activations_factor = 4 if args.checkpoint_activations else 3 - flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * - (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + - (vocab_size / (16. * num_layers * hidden_size))) + flops_per_iteration = ( + 24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2) + ) * (1.0 + (args.max_seq_length / (6.0 * hidden_size)) + (vocab_size / (16.0 * num_layers * hidden_size))) tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12)) return samples_per_second, tflops, approx_parameters_in_billions @@ -106,9 +107,9 @@ def synchronize(): def log_args(logger, args): - logger.info('--------args----------') - message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()]) - message += '\n' - message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()]) + logger.info("--------args----------") + message = "\n".join([f"{k:<30}: {v}" for k, v in vars(args).items()]) + message += "\n" + message += "\n".join([f"{k:<30}: {v}" for k, v in gpc.config.items()]) logger.info(message) - logger.info('--------args----------\n') + logger.info("--------args----------\n") diff --git a/examples/community/roberta/pretraining/utils/global_vars.py b/examples/community/roberta/pretraining/utils/global_vars.py index 9eef19e71614..176c0a5b3474 100644 --- a/examples/community/roberta/pretraining/utils/global_vars.py +++ b/examples/community/roberta/pretraining/utils/global_vars.py @@ -16,21 +16,21 @@ def set_global_variables(launch_time, tensorboard_path): def _set_timers(): """Initialize timers.""" global _GLOBAL_TIMERS - _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') + _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers") _GLOBAL_TIMERS = Timers() def _set_tensorboard_writer(launch_time, tensorboard_path): """Set tensorboard writer.""" global _GLOBAL_TENSORBOARD_WRITER - _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, 'tensorboard writer') + _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, "tensorboard writer") if torch.distributed.get_rank() == 0: - _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time) + _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f"/{launch_time}", launch_time) def get_timers(): """Return timers.""" - _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') + _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers") return _GLOBAL_TIMERS @@ -42,12 +42,12 @@ def get_tensorboard_writer(): def _ensure_var_is_initialized(var, name): """Make sure the input variable is not None.""" - assert var is not None, '{} is not initialized.'.format(name) + assert var is not None, "{} is not initialized.".format(name) def _ensure_var_is_not_initialized(var, name): """Make sure the input variable is not None.""" - assert var is None, '{} is already initialized.'.format(name) + assert var is None, "{} is already initialized.".format(name) class _Timer: @@ -68,9 +68,9 @@ def start(self): def stop(self): """Stop the timer.""" - assert self.started_, 'timer is not started' + assert self.started_, "timer is not started" torch.cuda.synchronize() - self.elapsed_ += (time.time() - self.start_time) + self.elapsed_ += time.time() - self.start_time self.started_ = False def reset(self): @@ -114,15 +114,15 @@ def write(self, names, writer, iteration, normalizer=1.0, reset=False): assert normalizer > 0.0 for name in names: value = self.timers[name].elapsed(reset=reset) / normalizer - writer.add_scalar(name + '-time', value, iteration) + writer.add_scalar(name + "-time", value, iteration) def log(self, names, normalizer=1.0, reset=True): """Log a group of timers.""" assert normalizer > 0.0 - string = 'time (ms)' + string = "time (ms)" for name in names: elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer - string += ' | {}: {:.2f}'.format(name, elapsed_time) + string += " | {}: {:.2f}".format(name, elapsed_time) if torch.distributed.is_initialized(): if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): print(string, flush=True) diff --git a/examples/community/roberta/pretraining/utils/logger.py b/examples/community/roberta/pretraining/utils/logger.py index 75c9bf4bef25..9913892b89e9 100644 --- a/examples/community/roberta/pretraining/utils/logger.py +++ b/examples/community/roberta/pretraining/utils/logger.py @@ -1,16 +1,14 @@ import logging -import os import torch.distributed as dist -logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt='%m/%d/%Y %H:%M:%S', - level=logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO +) logger = logging.getLogger(__name__) -class Logger(): - +class Logger: def __init__(self, log_path, cuda=False, debug=False): self.logger = logging.getLogger(__name__) self.cuda = cuda @@ -23,8 +21,8 @@ def info(self, message, log_=True, print_=True, *args, **kwargs): self.logger.info(message, *args, **kwargs) if log_: - with open(self.log_path, 'a+') as f_log: - f_log.write(message + '\n') + with open(self.log_path, "a+") as f_log: + f_log.write(message + "\n") def error(self, message, *args, **kwargs): self.logger.error(message, *args, **kwargs) diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md index 0c7f42ded318..b63896524909 100644 --- a/examples/images/diffusion/README.md +++ b/examples/images/diffusion/README.md @@ -132,7 +132,7 @@ bash train_colossalai.sh ``` It is important for you to configure your volume mapping in order to get the best training experience. -1. **Mandatory**, mount your prepared data to `/data/scratch` via `-v :/data/scratch`, where you need to replace `` with the actual data path on your machine. Notice that within docker we need to transform the Windows path to a Linux one, e.g. `C:\User\Desktop` into `/mnt/c/User/Desktop`. +1. **Mandatory**, mount your prepared data to `/data/scratch` via `-v :/data/scratch`, where you need to replace `` with the actual data path on your machine. Notice that within docker we need to transform the Windows path to a Linux one, e.g. `C:\User\Desktop` into `/mnt/c/User/Desktop`. 2. **Recommended**, store the downloaded model weights to your host machine instead of the container directory via `-v :/root/.cache/huggingface`, where you need to replace the `` with the actual path. In this way, you don't have to repeatedly download the pretrained weights for every `docker run`. 3. **Optional**, if you encounter any problem stating that shared memory is insufficient inside container, please add `-v /dev/shm:/dev/shm` to your `docker run` command. diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml index f3ae3ddb5ff6..72dc05b649a4 100644 --- a/examples/images/diffusion/configs/train_ddp.yaml +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -80,7 +80,7 @@ data: lightning: trainer: - accelerator: 'gpu' + accelerator: 'gpu' devices: 8 log_gpu_memory: all max_epochs: 2 diff --git a/examples/images/diffusion/ldm/data/base.py b/examples/images/diffusion/ldm/data/base.py index a12492c95a16..11bd0c5954a2 100644 --- a/examples/images/diffusion/ldm/data/base.py +++ b/examples/images/diffusion/ldm/data/base.py @@ -1,17 +1,15 @@ -import math import os -from abc import abstractmethod import cv2 import numpy as np import torch -from torch.utils.data import ChainDataset, ConcatDataset, Dataset, IterableDataset +from torch.utils.data import IterableDataset class Txt2ImgIterableBaseDataset(IterableDataset): - ''' + """ Define an interface to make the IterableDatasets for text2img data chainable - ''' + """ def __init__(self, file_path: str, rank, world_size): super().__init__() @@ -20,8 +18,8 @@ def __init__(self, file_path: str, rank, world_size): self.file_list = [] self.txt_list = [] self.info = self._get_file_info(file_path) - self.start = self.info['start'] - self.end = self.info['end'] + self.start = self.info["start"] + self.end = self.info["end"] self.rank = rank self.world_size = world_size @@ -33,7 +31,7 @@ def __init__(self, file_path: str, rank, world_size): self.num_records = self.end - self.start self.valid_ids = [i for i in range(self.end)] - print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.") def __len__(self): # return self.iter_end - self.iter_start @@ -48,7 +46,7 @@ def _sample_generator(self, start, end): for idx in range(start, end): file_name = self.file_list[idx] txt_name = self.txt_list[idx] - f_ = open(txt_name, 'r') + f_ = open(txt_name, "r") txt_ = f_.read() f_.close() image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1) @@ -57,18 +55,17 @@ def _sample_generator(self, start, end): yield {"txt": txt_, "image": image} def _get_file_info(self, file_path): - info = \ - { + info = { "start": 1, "end": 0, } - self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i] + self.folder_list = [file_path + i for i in os.listdir(file_path) if "." not in i] for folder in self.folder_list: - files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i] - txts = [k.replace('jpg', 'txt') for k in files] + files = [folder + "/" + i for i in os.listdir(folder) if "jpg" in i] + txts = [k.replace("jpg", "txt") for k in files] self.file_list.extend(files) self.txt_list.extend(txts) - info['end'] = len(self.file_list) + info["end"] = len(self.file_list) # with open(file_path, 'r') as fin: # for _ in enumerate(fin): # info['end'] += 1 diff --git a/examples/images/diffusion/ldm/data/cifar10.py b/examples/images/diffusion/ldm/data/cifar10.py index 53cd61263b47..85c6e1b5dd38 100644 --- a/examples/images/diffusion/ldm/data/cifar10.py +++ b/examples/images/diffusion/ldm/data/cifar10.py @@ -1,15 +1,16 @@ +import json +from pathlib import Path from typing import Dict -import numpy as np -from omegaconf import DictConfig, ListConfig + import torch -from torch.utils.data import Dataset -from pathlib import Path -import json -from PIL import Image -from torchvision import transforms +from datasets import load_dataset from einops import rearrange from ldm.util import instantiate_from_config -from datasets import load_dataset +from omegaconf import DictConfig, ListConfig +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + def make_multi_folder_data(paths, caption_files=None, **kwargs): """Make a concat dataset from multiple folders @@ -19,10 +20,9 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): """ list_of_paths = [] if isinstance(paths, (Dict, DictConfig)): - assert caption_files is None, \ - "Caption files not yet supported for repeats" + assert caption_files is None, "Caption files not yet supported for repeats" for folder_path, repeats in paths.items(): - list_of_paths.extend([folder_path]*repeats) + list_of_paths.extend([folder_path] * repeats) paths = list_of_paths if caption_files is not None: @@ -31,8 +31,10 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): datasets = [FolderData(p, **kwargs) for p in paths] return torch.utils.data.ConcatDataset(datasets) + class FolderData(Dataset): - def __init__(self, + def __init__( + self, root_dir, caption_file=None, image_transforms=[], @@ -40,7 +42,7 @@ def __init__(self, default_caption="", postprocess=None, return_paths=False, - ) -> None: + ) -> None: """Create a dataset from a folder of images. If you pass in a root directory it will be searched for images ending in ext (ext can be a list) @@ -75,12 +77,12 @@ def __init__(self, self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) if isinstance(image_transforms, ListConfig): image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))] + ) image_transforms = transforms.Compose(image_transforms) self.tform = image_transforms - def __len__(self): if self.captions is not None: return len(self.captions.keys()) @@ -94,7 +96,7 @@ def __getitem__(self, index): caption = self.captions.get(chosen, None) if caption is None: caption = self.default_caption - filename = self.root_dir/chosen + filename = self.root_dir / chosen else: filename = self.paths[index] @@ -119,22 +121,23 @@ def process_im(self, im): im = im.convert("RGB") return self.tform(im) + def hf_dataset( name, image_transforms=[], image_column="img", label_column="label", text_column="txt", - split='train', - image_key='image', - caption_key='txt', - ): - """Make huggingface dataset with appropriate list of transforms applied - """ + split="train", + image_key="image", + caption_key="txt", +): + """Make huggingface dataset with appropriate list of transforms applied""" ds = load_dataset(name, split=split) image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))] + ) tform = transforms.Compose(image_transforms) assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" @@ -144,7 +147,18 @@ def pre_process(examples): processed = {} processed[image_key] = [tform(im) for im in examples[image_column]] - label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"} + label_to_text_dict = { + 0: "airplane", + 1: "automobile", + 2: "bird", + 3: "cat", + 4: "deer", + 5: "dog", + 6: "frog", + 7: "horse", + 8: "ship", + 9: "truck", + } processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]] @@ -153,6 +167,7 @@ def pre_process(examples): ds.set_transform(pre_process) return ds + class TextOnly(Dataset): def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): """Returns only captions with dummy images""" @@ -166,7 +181,7 @@ def __init__(self, captions, output_size, image_key="image", caption_key="txt", if n_gpus > 1: # hack to make sure that all the captions appear on each gpu - repeated = [n_gpus*[x] for x in self.captions] + repeated = [n_gpus * [x] for x in self.captions] self.captions = [] [self.captions.extend(x) for x in repeated] @@ -175,10 +190,10 @@ def __len__(self): def __getitem__(self, index): dummy_im = torch.zeros(3, self.output_size, self.output_size) - dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c') + dummy_im = rearrange(dummy_im * 2.0 - 1.0, "c h w -> h w c") return {self.image_key: dummy_im, self.caption_key: self.captions[index]} def _load_caption_file(self, filename): - with open(filename, 'rt') as f: + with open(filename, "rt") as f: captions = f.readlines() - return [x.strip('\n') for x in captions] \ No newline at end of file + return [x.strip("\n") for x in captions] diff --git a/examples/images/diffusion/ldm/data/imagenet.py b/examples/images/diffusion/ldm/data/imagenet.py index 1c473f9c6965..8483e16ab23a 100644 --- a/examples/images/diffusion/ldm/data/imagenet.py +++ b/examples/images/diffusion/ldm/data/imagenet.py @@ -1,32 +1,35 @@ -import os, yaml, pickle, shutil, tarfile, glob -import cv2 +import glob +import os +import pickle +import shutil +import tarfile +from functools import partial + import albumentations -import PIL +import cv2 import numpy as np +import PIL +import taming.data.utils as tdu import torchvision.transforms.functional as TF +import yaml +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light from omegaconf import OmegaConf -from functools import partial from PIL import Image -from tqdm import tqdm +from taming.data.imagenet import ImagePaths, download, give_synsets_from_indices, retrieve, str_to_indices from torch.utils.data import Dataset, Subset - -import taming.data.utils as tdu -from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve -from taming.data.imagenet import ImagePaths - -from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light +from tqdm import tqdm def synset2idx(path_to_yaml="data/index_synset.yaml"): with open(path_to_yaml) as f: di2s = yaml.load(f) - return dict((v,k) for k,v in di2s.items()) + return dict((v, k) for k, v in di2s.items()) class ImageNetBase(Dataset): def __init__(self, config=None): self.config = config or OmegaConf.create() - if not type(self.config)==dict: + if not type(self.config) == dict: self.config = OmegaConf.to_container(self.config) self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) self.process_images = True # if False we skip loading & processing images and self.data contains filepaths @@ -46,9 +49,11 @@ def _prepare(self): raise NotImplementedError() def _filter_relpaths(self, relpaths): - ignore = set([ - "n06596364_9591.JPEG", - ]) + ignore = set( + [ + "n06596364_9591.JPEG", + ] + ) relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] if "sub_indices" in self.config: indices = str_to_indices(self.config["sub_indices"]) @@ -67,20 +72,19 @@ def _prepare_synset_to_human(self): SIZE = 2655750 URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" self.human_dict = os.path.join(self.root, "synset_human.txt") - if (not os.path.exists(self.human_dict) or - not os.path.getsize(self.human_dict)==SIZE): + if not os.path.exists(self.human_dict) or not os.path.getsize(self.human_dict) == SIZE: download(URL, self.human_dict) def _prepare_idx_to_synset(self): URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" self.idx2syn = os.path.join(self.root, "index_synset.yaml") - if (not os.path.exists(self.idx2syn)): + if not os.path.exists(self.idx2syn): download(URL, self.idx2syn) def _prepare_human_to_integer_label(self): URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") - if (not os.path.exists(self.human2integer)): + if not os.path.exists(self.human2integer): download(URL, self.human2integer) with open(self.human2integer, "r") as f: lines = f.read().splitlines() @@ -122,11 +126,12 @@ def _load(self): if self.process_images: self.size = retrieve(self.config, "size", default=256) - self.data = ImagePaths(self.abspaths, - labels=labels, - size=self.size, - random_crop=self.random_crop, - ) + self.data = ImagePaths( + self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) else: self.data = self.abspaths @@ -157,8 +162,7 @@ def _prepare(self): self.datadir = os.path.join(self.root, "data") self.txt_filelist = os.path.join(self.root, "filelist.txt") self.expected_length = 1281167 - self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", - default=True) + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True) if not tdu.is_prepared(self.root): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -166,8 +170,9 @@ def _prepare(self): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -179,7 +184,7 @@ def _prepare(self): print("Extracting sub-tars.") subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) for subpath in tqdm(subpaths): - subdir = subpath[:-len(".tar")] + subdir = subpath[: -len(".tar")] os.makedirs(subdir, exist_ok=True) with tarfile.open(subpath, "r:") as tar: tar.extractall(path=subdir) @@ -187,7 +192,7 @@ def _prepare(self): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) @@ -222,8 +227,7 @@ def _prepare(self): self.datadir = os.path.join(self.root, "data") self.txt_filelist = os.path.join(self.root, "filelist.txt") self.expected_length = 50000 - self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", - default=False) + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False) if not tdu.is_prepared(self.root): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -231,8 +235,9 @@ def _prepare(self): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -242,7 +247,7 @@ def _prepare(self): tar.extractall(path=datadir) vspath = os.path.join(self.root, self.FILES[1]) - if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]: download(self.VS_URL, vspath) with open(vspath, "r") as f: @@ -261,18 +266,15 @@ def _prepare(self): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) tdu.mark_prepared(self.root) - class ImageNetSR(Dataset): - def __init__(self, size=None, - degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., - random_crop=True): + def __init__(self, size=None, degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.0, random_crop=True): """ Imagenet Superresolution Dataloader Performs following ops in order: @@ -296,12 +298,12 @@ def __init__(self, size=None, self.LR_size = int(size / downscale_f) self.min_crop_f = min_crop_f self.max_crop_f = max_crop_f - assert(max_crop_f <= 1.) + assert max_crop_f <= 1.0 self.center_crop = not random_crop self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) - self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow if degradation == "bsrgan": self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) @@ -311,17 +313,17 @@ def __init__(self, size=None, else: interpolation_fn = { - "cv_nearest": cv2.INTER_NEAREST, - "cv_bilinear": cv2.INTER_LINEAR, - "cv_bicubic": cv2.INTER_CUBIC, - "cv_area": cv2.INTER_AREA, - "cv_lanczos": cv2.INTER_LANCZOS4, - "pil_nearest": PIL.Image.NEAREST, - "pil_bilinear": PIL.Image.BILINEAR, - "pil_bicubic": PIL.Image.BICUBIC, - "pil_box": PIL.Image.BOX, - "pil_hamming": PIL.Image.HAMMING, - "pil_lanczos": PIL.Image.LANCZOS, + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, }[degradation] self.pil_interpolation = degradation.startswith("pil_") @@ -330,8 +332,9 @@ def __init__(self, size=None, self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) else: - self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, - interpolation=interpolation_fn) + self.degradation_process = albumentations.SmallestMaxSize( + max_size=self.LR_size, interpolation=interpolation_fn + ) def __len__(self): return len(self.base) @@ -366,8 +369,8 @@ def __getitem__(self, i): else: LR_image = self.degradation_process(image=image)["image"] - example["image"] = (image/127.5 - 1.0).astype(np.float32) - example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32) return example @@ -379,7 +382,9 @@ def __init__(self, **kwargs): def get_base(self): with open("data/imagenet_train_hr_indices.p", "rb") as f: indices = pickle.load(f) - dset = ImageNetTrain(process_images=False,) + dset = ImageNetTrain( + process_images=False, + ) return Subset(dset, indices) @@ -390,5 +395,7 @@ def __init__(self, **kwargs): def get_base(self): with open("data/imagenet_val_hr_indices.p", "rb") as f: indices = pickle.load(f) - dset = ImageNetValidation(process_images=False,) + dset = ImageNetValidation( + process_images=False, + ) return Subset(dset, indices) diff --git a/examples/images/diffusion/ldm/data/lsun.py b/examples/images/diffusion/ldm/data/lsun.py index f5bf26c14254..e5c374aa2d51 100644 --- a/examples/images/diffusion/ldm/data/lsun.py +++ b/examples/images/diffusion/ldm/data/lsun.py @@ -1,47 +1,49 @@ import os + import numpy as np import PIL from PIL import Image from torch.utils.data import Dataset from torchvision import transforms + # This class is used to create a dataset of images from LSUN dataset for training class LSUNBase(Dataset): - def __init__(self, - txt_file, # path to the text file containing the list of image paths - data_root, # root directory of the LSUN dataset - size=None, # the size of images to resize to - interpolation="bicubic", # interpolation method to be used while resizing - flip_p=0.5 # probability of random horizontal flipping - ): - self.data_paths = txt_file # store path to text file containing list of images - self.data_root = data_root # store path to root directory of the dataset - with open(self.data_paths, "r") as f: # open and read the text file - self.image_paths = f.read().splitlines() # read the lines of the file and store as list - self._length = len(self.image_paths) # store the number of images - + def __init__( + self, + txt_file, # path to the text file containing the list of image paths + data_root, # root directory of the LSUN dataset + size=None, # the size of images to resize to + interpolation="bicubic", # interpolation method to be used while resizing + flip_p=0.5, # probability of random horizontal flipping + ): + self.data_paths = txt_file # store path to text file containing list of images + self.data_root = data_root # store path to root directory of the dataset + with open(self.data_paths, "r") as f: # open and read the text file + self.image_paths = f.read().splitlines() # read the lines of the file and store as list + self._length = len(self.image_paths) # store the number of images + # create dictionary to hold image path information self.labels = { "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, l) - for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths], } # set the image size to be resized - self.size = size + self.size = size # set the interpolation method for resizing the image - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] + self.interpolation = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] # randomly flip the image horizontally with a given probability self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): # return the length of dataset return self._length - def __getitem__(self, i): # get the image path for the given index @@ -52,59 +54,71 @@ def __getitem__(self, i): image = image.convert("RGB") # default to score-sde preprocessing - - img = np.array(image).astype(np.uint8) # convert image to numpy array - crop = min(img.shape[0], img.shape[1]) # crop the image to a square shape - h, w, = img.shape[0], img.shape[1] # get the height and width of image - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] # crop the image to a square shape - - image = Image.fromarray(img) # create an image from numpy array - if self.size is not None: # if image size is provided, resize the image + + img = np.array(image).astype(np.uint8) # convert image to numpy array + crop = min(img.shape[0], img.shape[1]) # crop the image to a square shape + ( + h, + w, + ) = ( + img.shape[0], + img.shape[1], + ) # get the height and width of image + img = img[ + (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2 + ] # crop the image to a square shape + + image = Image.fromarray(img) # create an image from numpy array + if self.size is not None: # if image size is provided, resize the image image = image.resize((self.size, self.size), resample=self.interpolation) - image = self.flip(image) # flip the image horizontally with the given probability - image = np.array(image).astype(np.uint8) + image = self.flip(image) # flip the image horizontally with the given probability + image = np.array(image).astype(np.uint8) example["image"] = (image / 127.5 - 1.0).astype(np.float32) # normalize the image values and convert to float32 - return example # return the example dictionary containing the image and its file paths + return example # return the example dictionary containing the image and its file paths + -#A dataset class for LSUN Churches training set. -# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. +# A dataset class for LSUN Churches training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. # The text file containing the paths to the images and the root directory where the images are stored are passed as arguments. Any additional keyword arguments passed to this class will be forwarded to the constructor of the parent class. class LSUNChurchesTrain(LSUNBase): def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) -#A dataset class for LSUN Churches validation set. + +# A dataset class for LSUN Churches validation set. # It is similar to LSUNChurchesTrain except that it uses a different text file and sets the flip probability to zero by default. class LSUNChurchesValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", - flip_p=flip_p, **kwargs) + def __init__(self, flip_p=0.0, **kwargs): + super().__init__( + txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", flip_p=flip_p, **kwargs + ) + -# A dataset class for LSUN Bedrooms training set. -# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. +# A dataset class for LSUN Bedrooms training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. class LSUNBedroomsTrain(LSUNBase): def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) -# A dataset class for LSUN Bedrooms validation set. + +# A dataset class for LSUN Bedrooms validation set. # It is similar to LSUNBedroomsTrain except that it uses a different text file and sets the flip probability to zero by default. class LSUNBedroomsValidation(LSUNBase): def __init__(self, flip_p=0.0, **kwargs): - super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", - flip_p=flip_p, **kwargs) + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", flip_p=flip_p, **kwargs) -# A dataset class for LSUN Cats training set. -# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. + +# A dataset class for LSUN Cats training set. +# It initializes by calling the constructor of LSUNBase class and passing the appropriate arguments. # The text file containing the paths to the images and the root directory where the images are stored are passed as arguments. class LSUNCatsTrain(LSUNBase): def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) -# A dataset class for LSUN Cats validation set. + +# A dataset class for LSUN Cats validation set. # It is similar to LSUNCatsTrain except that it uses a different text file and sets the flip probability to zero by default. class LSUNCatsValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", - flip_p=flip_p, **kwargs) + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", flip_p=flip_p, **kwargs) diff --git a/examples/images/diffusion/ldm/data/teyvat.py b/examples/images/diffusion/ldm/data/teyvat.py index eb5d3ea469d4..4a50a78f2dbc 100644 --- a/examples/images/diffusion/ldm/data/teyvat.py +++ b/examples/images/diffusion/ldm/data/teyvat.py @@ -1,15 +1,16 @@ +import json +from pathlib import Path from typing import Dict -import numpy as np -from omegaconf import DictConfig, ListConfig + import torch -from torch.utils.data import Dataset -from pathlib import Path -import json -from PIL import Image -from torchvision import transforms +from datasets import load_dataset from einops import rearrange from ldm.util import instantiate_from_config -from datasets import load_dataset +from omegaconf import DictConfig, ListConfig +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + def make_multi_folder_data(paths, caption_files=None, **kwargs): """Make a concat dataset from multiple folders @@ -19,10 +20,9 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): """ list_of_paths = [] if isinstance(paths, (Dict, DictConfig)): - assert caption_files is None, \ - "Caption files not yet supported for repeats" + assert caption_files is None, "Caption files not yet supported for repeats" for folder_path, repeats in paths.items(): - list_of_paths.extend([folder_path]*repeats) + list_of_paths.extend([folder_path] * repeats) paths = list_of_paths if caption_files is not None: @@ -31,8 +31,10 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): datasets = [FolderData(p, **kwargs) for p in paths] return torch.utils.data.ConcatDataset(datasets) + class FolderData(Dataset): - def __init__(self, + def __init__( + self, root_dir, caption_file=None, image_transforms=[], @@ -40,7 +42,7 @@ def __init__(self, default_caption="", postprocess=None, return_paths=False, - ) -> None: + ) -> None: """Create a dataset from a folder of images. If you pass in a root directory it will be searched for images ending in ext (ext can be a list) @@ -75,12 +77,12 @@ def __init__(self, self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) if isinstance(image_transforms, ListConfig): image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c"))] + ) image_transforms = transforms.Compose(image_transforms) self.tform = image_transforms - def __len__(self): if self.captions is not None: return len(self.captions.keys()) @@ -94,7 +96,7 @@ def __getitem__(self, index): caption = self.captions.get(chosen, None) if caption is None: caption = self.default_caption - filename = self.root_dir/chosen + filename = self.root_dir / chosen else: filename = self.paths[index] @@ -119,23 +121,26 @@ def process_im(self, im): im = im.convert("RGB") return self.tform(im) + def hf_dataset( - path = "Fazzie/Teyvat", + path="Fazzie/Teyvat", image_transforms=[], image_column="image", text_column="text", - image_key='image', - caption_key='txt', - ): - """Make huggingface dataset with appropriate list of transforms applied - """ + image_key="image", + caption_key="txt", +): + """Make huggingface dataset with appropriate list of transforms applied""" ds = load_dataset(path, name="train") ds = ds["train"] image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.Resize((256, 256)), - transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))] - ) + image_transforms.extend( + [ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2.0 - 1.0, "c h w -> h w c")), + ] + ) tform = transforms.Compose(image_transforms) assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" @@ -149,4 +154,4 @@ def pre_process(examples): return processed ds.set_transform(pre_process) - return ds \ No newline at end of file + return ds diff --git a/examples/images/diffusion/ldm/lr_scheduler.py b/examples/images/diffusion/ldm/lr_scheduler.py index be39da9ca6da..f4efb12f28b8 100644 --- a/examples/images/diffusion/ldm/lr_scheduler.py +++ b/examples/images/diffusion/ldm/lr_scheduler.py @@ -5,18 +5,20 @@ class LambdaWarmUpCosineScheduler: """ note: use with a base_lr of 1.0 """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): self.lr_warm_up_steps = warm_up_steps self.lr_start = lr_start self.lr_min = lr_min self.lr_max = lr_max self.lr_max_decay_steps = max_decay_steps - self.last_lr = 0. + self.last_lr = 0.0 self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start self.last_lr = lr @@ -24,13 +26,12 @@ def schedule(self, n, **kwargs): else: t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) t = min(t, 1.0) - lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi)) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) self.last_lr = lr return lr def __call__(self, n, **kwargs): - return self.schedule(n,**kwargs) + return self.schedule(n, **kwargs) class LambdaWarmUpCosineScheduler2: @@ -38,6 +39,7 @@ class LambdaWarmUpCosineScheduler2: supports repeated iterations, configurable via lists note: use with a base_lr of 1.0. """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) self.lr_warm_up_steps = warm_up_steps @@ -46,7 +48,7 @@ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosit self.f_max = f_max self.cycle_lengths = cycle_lengths self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) - self.last_f = 0. + self.last_f = 0.0 self.verbosity_interval = verbosity_interval def find_in_interval(self, n): @@ -60,8 +62,8 @@ def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f @@ -69,8 +71,7 @@ def schedule(self, n, **kwargs): else: t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( - 1 + np.cos(t * np.pi)) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) self.last_f = f return f @@ -79,20 +80,20 @@ def __call__(self, n, **kwargs): class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): - def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( + self.cycle_lengths[cycle] + ) self.last_f = f return f - diff --git a/examples/images/diffusion/ldm/models/autoencoder.py b/examples/images/diffusion/ldm/models/autoencoder.py index f0a69fe63a8c..1c54dfe74f74 100644 --- a/examples/images/diffusion/ldm/models/autoencoder.py +++ b/examples/images/diffusion/ldm/models/autoencoder.py @@ -1,29 +1,28 @@ -import torch -import lightning.pytorch as pl - -from torch import nn -from torch.nn import functional as F -from torch.nn import Identity from contextlib import contextmanager -from ldm.modules.diffusionmodules.model import Encoder, Decoder +import lightning.pytorch as pl +import torch +from ldm.modules.diffusionmodules.model import Decoder, Encoder from ldm.modules.distributions.distributions import DiagonalGaussianDistribution from ldm.modules.ema import LitEma +from torch.nn import Identity +from torch.nn import functional as F class AutoencoderKL(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ema_decay=None, - learn_logvar=False - ): + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + ): super().__init__() self.learn_logvar = learn_logvar self.image_key = image_key @@ -31,11 +30,11 @@ def __init__(self, self.decoder = Decoder(**ddconfig) self.loss = Identity() assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim if colorize_nlabels is not None: - assert type(colorize_nlabels)==int + assert type(colorize_nlabels) == int self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor @@ -43,7 +42,7 @@ def __init__(self, self.use_ema = ema_decay is not None if self.use_ema: self.ema_decay = ema_decay - assert 0. < ema_decay < 1. + assert 0.0 < ema_decay < 1.0 self.model_ema = LitEma(self, decay=ema_decay) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") @@ -113,16 +112,30 @@ def training_step(self, batch, batch_idx, optimizer_idx): if optimizer_idx == 0: # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) return aeloss if optimizer_idx == 1: # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) @@ -137,11 +150,25 @@ def validation_step(self, batch, batch_idx): def _validation_step(self, batch, batch_idx, postfix=""): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) self.log_dict(log_dict_ae) @@ -150,15 +177,17 @@ def _validation_step(self, batch, batch_idx, postfix=""): def configure_optimizers(self): lr = self.learning_rate - ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( - self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + ae_params_list = ( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()) + ) if self.learn_logvar: print(f"{self.__class__.__name__}: Learning logvar") ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) return [opt_ae, opt_disc], [] def get_last_layer(self): @@ -195,7 +224,7 @@ def to_rgb(self, x): if not hasattr(self, "colorize"): self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x @@ -217,4 +246,3 @@ def quantize(self, x, *args, **kwargs): def forward(self, x, *args, **kwargs): return x - diff --git a/examples/images/diffusion/ldm/models/diffusion/classifier.py b/examples/images/diffusion/ldm/models/diffusion/classifier.py index 3cf12f093bea..73aba26c9d89 100644 --- a/examples/images/diffusion/ldm/models/diffusion/classifier.py +++ b/examples/images/diffusion/ldm/models/diffusion/classifier.py @@ -1,23 +1,21 @@ import os -import torch +from copy import deepcopy +from glob import glob + import lightning.pytorch as pl +import torch +from einops import rearrange +from ldm.lr_scheduler import LambdaLinearScheduler +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import default, ismap, log_txt_as_img +from natsort import natsorted from omegaconf import OmegaConf from torch.nn import functional as F from torch.optim import AdamW from torch.optim.lr_scheduler import LambdaLR -from copy import deepcopy -from einops import rearrange -from glob import glob -from natsort import natsorted -from ldm.models.diffusion.ddpm import LatentDiffusion -from ldm.lr_scheduler import LambdaLinearScheduler -from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel -from ldm.util import log_txt_as_img, default, ismap -__models__ = { - 'class_label': EncoderUNetModel, - 'segmentation': UNetModel -} +__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel} def disabled_train(self, mode=True): @@ -27,24 +25,25 @@ def disabled_train(self, mode=True): class NoisyLatentImageClassifier(pl.LightningModule): - - def __init__(self, - diffusion_path, - num_classes, - ckpt_path=None, - pool='attention', - label_key=None, - diffusion_ckpt_path=None, - scheduler_config=None, - weight_decay=1.e-2, - log_steps=10, - monitor='val/loss', - *args, - **kwargs): + def __init__( + self, + diffusion_path, + num_classes, + ckpt_path=None, + pool="attention", + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.0e-2, + log_steps=10, + monitor="val/loss", + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.num_classes = num_classes # get latest config of diffusion model - diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + diffusion_config = natsorted(glob(os.path.join(diffusion_path, "configs", "*-project.yaml")))[-1] self.diffusion_config = OmegaConf.load(diffusion_config).model self.diffusion_config.params.ckpt_path = diffusion_ckpt_path self.load_diffusion() @@ -54,10 +53,11 @@ def __init__(self, self.log_time_interval = self.diffusion_model.num_timesteps // log_steps self.log_steps = log_steps - self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ - else self.diffusion_model.cond_stage_key + self.label_key = ( + label_key if not hasattr(self.diffusion_model, "cond_stage_key") else self.diffusion_model.cond_stage_key + ) - assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + assert self.label_key is not None, "label_key neither in diffusion model nor in model.params" if self.label_key not in __models__: raise NotImplementedError() @@ -78,8 +78,9 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: print(f"Missing Keys: {missing}") @@ -87,7 +88,7 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): print(f"Unexpected Keys: {unexpected}") def load_diffusion(self): - model = LatentDiffusion(**self.diffusion_config.get('params',dict())) + model = LatentDiffusion(**self.diffusion_config.get("params", dict())) self.diffusion_model = model.eval() self.diffusion_model.train = disabled_train for param in self.diffusion_model.parameters(): @@ -97,14 +98,14 @@ def load_classifier(self, ckpt_path, pool): model_config = deepcopy(self.diffusion_config.params.unet_config.params) model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels model_config.out_channels = self.num_classes - if self.label_key == 'class_label': + if self.label_key == "class_label": model_config.pool = pool self.model = __models__[self.label_key](**model_config) if ckpt_path is not None: - print('#####################################################################') + print("#####################################################################") print(f'load from ckpt "{ckpt_path}"') - print('#####################################################################') + print("#####################################################################") self.init_from_ckpt(ckpt_path) @torch.no_grad() @@ -115,8 +116,9 @@ def get_x_noisy(self, x, t, noise=None): continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) # todo: make sure t+1 is correct here - return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, - continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + return self.diffusion_model.q_sample( + x_start=x, t=t, noise=noise, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod + ) def forward(self, x_noisy, t, *args, **kwargs): return self.model(x_noisy, t) @@ -126,7 +128,7 @@ def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') + x = rearrange(x, "b h w c -> b c h w") x = x.to(memory_format=torch.contiguous_format).float() return x @@ -134,15 +136,15 @@ def get_input(self, batch, k): def get_conditioning(self, batch, k=None): if k is None: k = self.label_key - assert k is not None, 'Needs to provide label key' + assert k is not None, "Needs to provide label key" targets = batch[k].to(self.device) - if self.label_key == 'segmentation': - targets = rearrange(targets, 'b h w c -> b c h w') + if self.label_key == "segmentation": + targets = rearrange(targets, "b h w c -> b c h w") for down in range(self.numd): h, w = targets.shape[-2:] - targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest") # targets = rearrange(targets,'b c h w -> b h w c') @@ -157,25 +159,21 @@ def compute_top_k(self, logits, labels, k, reduction="mean"): def on_train_epoch_start(self): # save some memory - self.diffusion_model.model.to('cpu') + self.diffusion_model.model.to("cpu") @torch.no_grad() def write_logs(self, loss, logits, targets): - log_prefix = 'train' if self.training else 'val' + log_prefix = "train" if self.training else "val" log = {} log[f"{log_prefix}/loss"] = loss.mean() - log[f"{log_prefix}/acc@1"] = self.compute_top_k( - logits, targets, k=1, reduction="mean" - ) - log[f"{log_prefix}/acc@5"] = self.compute_top_k( - logits, targets, k=5, reduction="mean" - ) + log[f"{log_prefix}/acc@1"] = self.compute_top_k(logits, targets, k=1, reduction="mean") + log[f"{log_prefix}/acc@5"] = self.compute_top_k(logits, targets, k=5, reduction="mean") self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) - self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) - self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log("global_step", self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]["lr"] + self.log("lr_abs", lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) def shared_step(self, batch, t=None): x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) @@ -189,7 +187,7 @@ def shared_step(self, batch, t=None): x_noisy = self.get_x_noisy(x, t) logits = self(x_noisy, t) - loss = F.cross_entropy(logits, targets, reduction='none') + loss = F.cross_entropy(logits, targets, reduction="none") self.write_logs(loss.detach(), logits.detach(), targets.detach()) @@ -201,8 +199,10 @@ def training_step(self, batch, batch_idx): return loss def reset_noise_accs(self): - self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in - range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + self.noisy_acc = { + t: {"acc@1": [], "acc@5": []} + for t in range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t) + } def on_validation_start(self): self.reset_noise_accs() @@ -213,8 +213,8 @@ def validation_step(self, batch, batch_idx): for t in self.noisy_acc: _, logits, _, targets = self.shared_step(batch, t) - self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) - self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + self.noisy_acc[t]["acc@1"].append(self.compute_top_k(logits, targets, k=1, reduction="mean")) + self.noisy_acc[t]["acc@5"].append(self.compute_top_k(logits, targets, k=5, reduction="mean")) return loss @@ -222,15 +222,12 @@ def configure_optimizers(self): optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) if self.use_scheduler: - scheduler = LambdaLinearScheduler(**self.scheduler_config.get('params',dict())) + scheduler = LambdaLinearScheduler(**self.scheduler_config.get("params", dict())) print("Setting up LambdaLR scheduler...") scheduler = [ - { - 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] + {"scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1} + ] return [optimizer], scheduler return optimizer @@ -239,28 +236,28 @@ def configure_optimizers(self): def log_images(self, batch, N=8, *args, **kwargs): log = dict() x = self.get_input(batch, self.diffusion_model.first_stage_key) - log['inputs'] = x + log["inputs"] = x y = self.get_conditioning(batch) - if self.label_key == 'class_label': + if self.label_key == "class_label": y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) - log['labels'] = y + log["labels"] = y if ismap(y): - log['labels'] = self.diffusion_model.to_rgb(y) + log["labels"] = self.diffusion_model.to_rgb(y) for step in range(self.log_steps): current_time = step * self.log_time_interval _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) - log[f'inputs@t{current_time}'] = x_noisy + log[f"inputs@t{current_time}"] = x_noisy pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) - pred = rearrange(pred, 'b h w c -> b c h w') + pred = rearrange(pred, "b h w c -> b c h w") - log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred) for key in log: log[key] = log[key][:N] diff --git a/examples/images/diffusion/ldm/models/diffusion/ddim.py b/examples/images/diffusion/ldm/models/diffusion/ddim.py index 27ead0ea914c..a9e28792f864 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddim.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddim.py @@ -1,11 +1,15 @@ """SAMPLING ONLY.""" -import torch import numpy as np +import torch +from ldm.modules.diffusionmodules.util import ( + extract_into_tensor, + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) from tqdm import tqdm -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor - class DDIMSampler(object): def __init__(self, model, schedule="linear", **kwargs): @@ -20,67 +24,75 @@ def register_buffer(self, name, attr): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1))) # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - ucg_schedule=None, - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] + while isinstance(ctmp, list): + ctmp = ctmp[0] cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") @@ -98,35 +110,53 @@ def sample(self, # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule - ) + print(f"Data shape for DDIM sampling is {size}, eta {eta}") + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule, + ) return samples, intermediates @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None): + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ucg_schedule=None, + ): device = self.model.betas.device b = shape[0] if x_T is None: @@ -140,12 +170,12 @@ def ddim_sampling(self, cond, shape, subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -154,37 +184,60 @@ def ddim_sampling(self, cond, shape, if mask is not None: assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img + img = img_orig * mask + (1.0 - mask) * img if ucg_schedule is not None: assert len(ucg_schedule) == len(time_range) unconditional_guidance_scale = ucg_schedule[i] - outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold) + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) return img, intermediates @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): b, *_, device = *x.shape, x.device - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: model_output = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) @@ -194,13 +247,9 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F c_in = dict() for k in c: if isinstance(c[k], list): - c_in[k] = [torch.cat([ - unconditional_conditioning[k][i], - c[k][i]]) for i in range(len(c[k]))] + c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))] else: - c_in[k] = torch.cat([ - unconditional_conditioning[k], - c[k]]) + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) elif isinstance(c, list): c_in = list() assert isinstance(unconditional_conditioning, list) @@ -217,18 +266,20 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F e_t = model_output if score_corrector is not None: - assert self.model.parameterization == "eps", 'not implemented' + assert self.model.parameterization == "eps", "not implemented" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + ) sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) # current prediction for x_0 if self.model.parameterization != "v": @@ -243,16 +294,25 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F raise NotImplementedError() # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 @torch.no_grad() - def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, - unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + def encode( + self, + x0, + c, + t_enc, + use_original_steps=False, + return_intermediates=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + callback=None, + ): num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] assert t_enc <= num_reference_steps @@ -268,33 +328,37 @@ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=No x_next = x0 intermediates = [] inter_steps = [] - for i in tqdm(range(num_steps), desc='Encoding Image'): + for i in tqdm(range(num_steps), desc="Encoding Image"): t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) - if unconditional_guidance_scale == 1.: + if unconditional_guidance_scale == 1.0: noise_pred = self.model.apply_model(x_next, t, c) else: assert unconditional_conditioning is not None e_t_uncond, noise_pred = torch.chunk( - self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), - torch.cat((unconditional_conditioning, c))), 2) + self.model.apply_model( + torch.cat((x_next, x_next)), torch.cat((t, t)), torch.cat((unconditional_conditioning, c)) + ), + 2, + ) noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = alphas_next[i].sqrt() * ( - (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + weighted_noise_pred = ( + alphas_next[i].sqrt() * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + ) x_next = xt_weighted + weighted_noise_pred - if return_intermediates and i % ( - num_steps // return_intermediates) == 0 and i < num_steps - 1: + if return_intermediates and i % (num_steps // return_intermediates) == 0 and i < num_steps - 1: intermediates.append(x_next) inter_steps.append(i) elif return_intermediates and i >= num_steps - 2: intermediates.append(x_next) inter_steps.append(i) - if callback: callback(i) + if callback: + callback(i) - out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + out = {"x_encoded": x_next, "intermediate_steps": inter_steps} if return_intermediates: - out.update({'intermediates': intermediates}) + out.update({"intermediates": intermediates}) return x_next, out @torch.no_grad() @@ -310,13 +374,22 @@ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): if noise is None: noise = torch.randn_like(x0) - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): - + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + callback=None, + ): timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps timesteps = timesteps[:t_start] @@ -324,13 +397,20 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco total_steps = timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) x_dec = x_latent for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - return x_dec \ No newline at end of file + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + if callback: + callback(i) + return x_dec diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py index 842ec1371ea0..20e26256e18e 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -27,23 +27,22 @@ from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage from ldm.models.diffusion.ddim import * from ldm.models.diffusion.ddim import DDIMSampler -from ldm.modules.midas.api import MiDaSInference from ldm.modules.diffusionmodules.model import * -from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model from ldm.modules.diffusionmodules.openaimodel import * -from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d, UNetModel +from ldm.modules.diffusionmodules.openaimodel import UNetModel +from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl -from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from ldm.modules.ema import LitEma from ldm.modules.encoders.modules import * +from ldm.modules.midas.api import MiDaSInference from ldm.util import count_params, default, exists, isimage, ismap, log_txt_as_img, mean_flat from omegaconf import ListConfig from torch.optim.lr_scheduler import LambdaLR from torchvision.utils import make_grid from tqdm import tqdm -__conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'} +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} def disabled_train(self, mode=True): @@ -78,15 +77,15 @@ def __init__( linear_end=2e-2, cosine_s=8e-3, given_betas=None, - original_elbo_weight=0., - v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1., + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules + parameterization="eps", # all assuming fixed variance schedules scheduler_config=None, use_positional_encodings=False, learn_logvar=False, - logvar_init=0., + logvar_init=0.0, use_fp16=True, make_it_fit=False, ucg_training=None, @@ -133,9 +132,9 @@ def __init__( if reset_ema: assert exists(ckpt) - ''' + """ Uncomment if you Use DDP Strategy - ''' + """ # if ckpt is not None: # self.init_from_ckpt(ckpt, ignore_keys=ignore_keys, only_model=load_only_unet) # if reset_ema: @@ -155,12 +154,14 @@ def __init__( self.linear_end = linear_end self.cosine_s = cosine_s - self.register_schedule(given_betas=given_betas, - beta_schedule=beta_schedule, - timesteps=timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s) + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) self.loss_type = loss_type @@ -174,67 +175,73 @@ def __init__( if self.ucg_training: self.ucg_prng = np.random.RandomState() - def register_schedule(self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3): + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): if exists(given_betas): betas = given_betas else: - betas = make_beta_schedule(beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. - betas + betas = make_beta_schedule( + beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s + ) + alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( - 1. - alphas_cumprod) + self.v_posterior * betas + posterior_variance = (1 - self.v_posterior) * betas * (1.0 - alphas_cumprod_prev) / ( + 1.0 - alphas_cumprod + ) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer('posterior_variance', to_torch(posterior_variance)) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', - to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', - to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + self.register_buffer("posterior_log_variance_clipped", to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer( + "posterior_mean_coef1", to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + ) + self.register_buffer( + "posterior_mean_coef2", to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)) + ) if self.parameterization == "eps": - lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + lvlb_weights = self.betas**2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) + ) elif self.parameterization == "x0": - lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) elif self.parameterization == "v": - lvlb_weights = torch.ones_like(self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * - (1 - self.alphas_cumprod))) + lvlb_weights = torch.ones_like( + self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + ) else: raise NotImplementedError("mu not supported") lvlb_weights[0] = lvlb_weights[1] - self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) assert not torch.isnan(self.lvlb_weights).all() @contextmanager @@ -265,9 +272,11 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): del sd[k] if self.make_it_fit: n_params = len([name for name, _ in itertools.chain(self.named_parameters(), self.named_buffers())]) - for name, param in tqdm(itertools.chain(self.named_parameters(), self.named_buffers()), - desc="Fitting old weights to new weights", - total=n_params): + for name, param in tqdm( + itertools.chain(self.named_parameters(), self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params, + ): if not name in sd: continue old_shape = sd[name].shape @@ -302,8 +311,9 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): sd[name] = new_param - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: rank_zero_info(f"Missing Keys:\n {missing}") @@ -317,28 +327,36 @@ def q_mean_variance(self, x_start, t): :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ - mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise) + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) def predict_start_from_z_and_v(self, x_t, t, v): # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) def predict_eps_from_z_and_v(self, x_t, t, v): - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) def q_posterior(self, x_start, x_t, t): - posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t) + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped @@ -350,7 +368,7 @@ def p_mean_variance(self, x, t, clip_denoised: bool): elif self.parameterization == "x0": x_recon = model_out if clip_denoised: - x_recon.clamp_(-1., 1.) + x_recon.clamp_(-1.0, 1.0) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @@ -370,10 +388,10 @@ def p_sample_loop(self, shape, return_intermediates=False): b = shape[0] img = torch.randn(shape, device=device) intermediates = [img] - for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): - img = self.p_sample(img, - torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised) + for i in tqdm(reversed(range(0, self.num_timesteps)), desc="Sampling t", total=self.num_timesteps): + img = self.p_sample( + img, torch.full((b,), i, device=device, dtype=torch.long), clip_denoised=self.clip_denoised + ) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) if return_intermediates: @@ -384,28 +402,33 @@ def p_sample_loop(self, shape, return_intermediates=False): def sample(self, batch_size=16, return_intermediates=False): image_size = self.image_size channels = self.channels - return self.p_sample_loop((batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates) + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), return_intermediates=return_intermediates + ) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) def get_v(self, x, noise, t): - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) def get_loss(self, pred, target, mean=True): - if self.loss_type == 'l1': + if self.loss_type == "l1": loss = (target - pred).abs() if mean: loss = loss.mean() - elif self.loss_type == 'l2': + elif self.loss_type == "l2": if mean: loss = torch.nn.functional.mse_loss(target, pred) else: - loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + loss = torch.nn.functional.mse_loss(target, pred, reduction="none") else: raise NotImplementedError("unknown loss type '{loss_type}'") @@ -428,17 +451,17 @@ def p_losses(self, x_start, t, noise=None): loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) - log_prefix = 'train' if self.training else 'val' + log_prefix = "train" if self.training else "val" - loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()}) loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb}) loss = loss_simple + self.original_elbo_weight * loss_vlb - loss_dict.update({f'{log_prefix}/loss': loss}) + loss_dict.update({f"{log_prefix}/loss": loss}) return loss, loss_dict @@ -452,7 +475,7 @@ def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') + x = rearrange(x, "b h w c -> b c h w") if self.use_fp16: x = x.to(memory_format=torch.contiguous_format).half() else: @@ -481,8 +504,8 @@ def training_step(self, batch, batch_idx): self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False) if self.use_scheduler: - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + lr = self.optimizers().param_groups[0]["lr"] + self.log("lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) return loss @@ -491,7 +514,7 @@ def validation_step(self, batch, batch_idx): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema} self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) @@ -501,8 +524,8 @@ def on_train_batch_end(self, *args, **kwargs): def _get_rows_from_list(self, samples): n_imgs_per_row = len(samples) - denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = rearrange(samples, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @@ -521,7 +544,7 @@ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwarg for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) @@ -556,29 +579,31 @@ def configure_optimizers(self): class LatentDiffusion(DDPM): """main class""" - def __init__(self, - first_stage_config, - cond_stage_config, - num_timesteps_cond=None, - cond_stage_key="image", - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - use_fp16=True, - force_null_conditioning=False, - *args, - **kwargs): + def __init__( + self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + use_fp16=True, + force_null_conditioning=False, + *args, + **kwargs, + ): self.force_null_conditioning = force_null_conditioning self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std - assert self.num_timesteps_cond <= kwargs['timesteps'] + assert self.num_timesteps_cond <= kwargs["timesteps"] # for backwards compatibility after implementation of DiffusionWrapper if conditioning_key is None: - conditioning_key = 'concat' if concat_mode else 'crossattn' - if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: + conditioning_key = "concat" if concat_mode else "crossattn" + if cond_stage_config == "__is_unconditional__" and not self.force_null_conditioning: conditioning_key = None super().__init__(conditioning_key=conditioning_key, *args, **kwargs) @@ -593,7 +618,7 @@ def __init__(self, if not scale_by_std: self.scale_factor = scale_factor else: - self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.register_buffer("scale_factor", torch.tensor(scale_factor)) self.first_stage_config = first_stage_config self.cond_stage_config = cond_stage_config self.instantiate_first_stage(first_stage_config) @@ -601,9 +626,9 @@ def __init__(self, self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.bbox_tokenizer = None - ''' + """ Uncomment if you Use DDP Strategy - ''' + """ # self.restarted_from_ckpt = False # if self.ckpt is not None: # self.init_from_ckpt(self.ckpt, self.ignore_keys) @@ -630,15 +655,18 @@ def configure_sharded_model(self) -> None: if self.reset_ema: assert self.use_ema rank_zero_info( - f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." + ) self.model_ema = LitEma(self.model) - self.register_schedule(given_betas=self.given_betas, - beta_schedule=self.beta_schedule, - timesteps=self.timesteps, - linear_start=self.linear_start, - linear_end=self.linear_end, - cosine_s=self.cosine_s) + self.register_schedule( + given_betas=self.given_betas, + beta_schedule=self.beta_schedule, + timesteps=self.timesteps, + linear_start=self.linear_start, + linear_end=self.linear_end, + cosine_s=self.cosine_s, + ) self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,)) if self.learn_logvar: @@ -654,20 +682,29 @@ def configure_sharded_model(self) -> None: if self.reset_ema: assert self.use_ema rank_zero_info( - f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." + ) self.model_ema = LitEma(self.model) - def make_cond_schedule(self,): + def make_cond_schedule( + self, + ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() - self.cond_ids[:self.num_timesteps_cond] = ids + self.cond_ids[: self.num_timesteps_cond] = ids @rank_zero_only @torch.no_grad() def on_train_batch_start(self, batch, batch_idx): # only for very first batch - if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: - assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + if ( + self.scale_by_std + and self.current_epoch == 0 + and self.global_step == 0 + and batch_idx == 0 + and not self.restarted_from_ckpt + ): + assert self.scale_factor == 1.0, "rather not use custom rescaling and std-rescaling simultaneously" # set rescale weight to 1./std of encodings rank_zero_info("### USING STD-RESCALING ###") x = super().get_input(batch, self.first_stage_key) @@ -675,17 +712,19 @@ def on_train_batch_start(self, batch, batch_idx): encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() del self.scale_factor - self.register_buffer('scale_factor', 1. / z.flatten().std()) + self.register_buffer("scale_factor", 1.0 / z.flatten().std()) rank_zero_info(f"setting self.scale_factor to {self.scale_factor}") rank_zero_info("### USING STD-RESCALING ###") - def register_schedule(self, - given_betas=None, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3): + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) self.shorten_cond_schedule = self.num_timesteps_cond > 1 @@ -718,15 +757,16 @@ def instantiate_cond_stage(self, config): model = FrozenOpenCLIPEmbedder(**config) self.cond_stage_model = model - def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + def _get_denoise_row_from_list(self, samples, desc="", force_no_decoder_quantization=False): denoise_row = [] for zd in tqdm(samples, desc=desc): denoise_row.append( - self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization)) + self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization) + ) n_imgs_per_row = len(denoise_row) - denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W - denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @@ -741,7 +781,7 @@ def get_first_stage_encoding(self, encoder_posterior): def get_learned_conditioning(self, c): if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + if hasattr(self.cond_stage_model, "encode") and callable(self.cond_stage_model.encode): c = self.cond_stage_model.encode(c) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() @@ -784,14 +824,17 @@ def get_weighting(self, h, w, Ly, Lx, device): if self.split_input_params["tie_braker"]: L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip(L_weighting, self.split_input_params["clip_min_tie_weight"], - self.split_input_params["clip_max_tie_weight"]) + L_weighting = torch.clip( + L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"], + ) L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) weighting = weighting * L_weighting return weighting - def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code """ :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) @@ -809,35 +852,39 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load on fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) elif uf > 1 and df == 1: fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) - fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, - padding=0, - stride=(stride[0] * uf, stride[1] * uf)) + fold_params2 = dict( + kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf), + ) fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) elif df > 1 and uf == 1: fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) unfold = torch.nn.Unfold(**fold_params) - fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, - padding=0, - stride=(stride[0] // df, stride[1] // df)) + fold_params2 = dict( + kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df), + ) fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) else: @@ -846,15 +893,17 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load on return fold, unfold, normalization, weighting @torch.no_grad() - def get_input(self, - batch, - k, - return_first_stage_outputs=False, - force_c_encode=False, - cond_key=None, - return_original_cond=False, - bs=None, - return_x=False): + def get_input( + self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + return_x=False, + ): x = super().get_input(batch, k) if bs is not None: x = x[:bs] @@ -866,9 +915,9 @@ def get_input(self, if cond_key is None: cond_key = self.cond_stage_key if cond_key != self.first_stage_key: - if cond_key in ['caption', 'coordinates_bbox', "txt"]: + if cond_key in ["caption", "coordinates_bbox", "txt"]: xc = batch[cond_key] - elif cond_key in ['class_label', 'cls']: + elif cond_key in ["class_label", "cls"]: xc = batch else: xc = super().get_input(batch, cond_key).to(self.device) @@ -887,14 +936,14 @@ def get_input(self, if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) ckey = __conditioning_keys__[self.model.conditioning_key] - c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y} else: c = None xc = None if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) - c = {'pos_x': pos_x, 'pos_y': pos_y} + c = {"pos_x": pos_x, "pos_y": pos_y} out = [z, c] if return_first_stage_outputs: xrec = self.decode_first_stage(z) @@ -912,9 +961,9 @@ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() + z = rearrange(z, "b h w c -> b c h w").contiguous() - z = 1. / self.scale_factor * z + z = 1.0 / self.scale_factor * z return self.first_stage_model.decode(z) @torch.no_grad() @@ -932,7 +981,7 @@ def forward(self, x, c, *args, **kwargs): assert c is not None if self.cond_stage_trainable: c = self.get_learned_conditioning(c) - if self.shorten_cond_schedule: # TODO: drop this option + if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) @@ -944,7 +993,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): else: if not isinstance(cond, list): cond = [cond] - key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + key = "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" cond = {key: cond} x_recon = self.model(x_noisy, t, **cond) @@ -955,8 +1004,9 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): return x_recon def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _prior_bpd(self, x_start): """ @@ -978,7 +1028,7 @@ def p_losses(self, x_start, cond, t, noise=None): model_output = self.apply_model(x_noisy, t, cond) loss_dict = {} - prefix = 'train' if self.training else 'val' + prefix = "train" if self.training else "val" if self.parameterization == "x0": target = x_start @@ -990,36 +1040,38 @@ def p_losses(self, x_start, cond, t, noise=None): raise NotImplementedError() loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) - loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()}) logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: - loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) - loss_dict.update({'logvar': self.logvar.data.mean()}) + loss_dict.update({f"{prefix}/loss_gamma": loss.mean()}) + loss_dict.update({"logvar": self.logvar.data.mean()}) loss = self.l_simple_weight * loss.mean() loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) - loss += (self.original_elbo_weight * loss_vlb) - loss_dict.update({f'{prefix}/loss': loss}) + loss_dict.update({f"{prefix}/loss_vlb": loss_vlb}) + loss += self.original_elbo_weight * loss_vlb + loss_dict.update({f"{prefix}/loss": loss}) return loss, loss_dict - def p_mean_variance(self, - x, - c, - t, - clip_denoised: bool, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - score_corrector=None, - corrector_kwargs=None): + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): t_in = t model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) @@ -1038,7 +1090,7 @@ def p_mean_variance(self, raise NotImplementedError() if clip_denoised: - x_recon.clamp_(-1., 1.) + x_recon.clamp_(-1.0, 1.0) if quantize_denoised: x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) @@ -1050,29 +1102,33 @@ def p_mean_variance(self, return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample(self, - x, - c, - t, - clip_denoised=False, - repeat_noise=False, - return_codebook_ids=False, - quantize_denoised=False, - return_x0=False, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None): + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance(x=x, - c=c, - t=t, - clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs) + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) if return_codebook_ids: raise DeprecationWarning("Support dropped.") model_mean, _, model_log_variance, logits = outputs @@ -1082,7 +1138,7 @@ def p_sample(self, model_mean, _, model_log_variance = outputs noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) @@ -1095,23 +1151,25 @@ def p_sample(self, return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() - def progressive_denoising(self, - cond, - shape, - verbose=True, - callback=None, - quantize_denoised=False, - img_callback=None, - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - batch_size=None, - x_T=None, - start_T=None, - log_every_t=None): + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): if not log_every_t: log_every_t = self.log_every_t timesteps = self.num_timesteps @@ -1128,40 +1186,47 @@ def progressive_denoising(self, if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] if not isinstance(cond[key], list) else list( - map(lambda x: x[:batch_size], cond[key])) for key in cond + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond } else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] if start_T is not None: timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', - total=timesteps) if verbose else reversed(range(0, timesteps)) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Progressive Generation", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) if type(temperature) == float: temperature = [temperature] * timesteps for i in iterator: ts = torch.full((b,), i, device=self.device, dtype=torch.long) if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' + assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - img, x0_partial = self.p_sample(img, - cond, - ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, - return_x0=True, - temperature=temperature[i], - noise_dropout=noise_dropout, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs) + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img + img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) @@ -1172,21 +1237,22 @@ def progressive_denoising(self, return img, intermediates @torch.no_grad() - def p_sample_loop(self, - cond, - shape, - return_intermediates=False, - x_T=None, - verbose=True, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - start_T=None, - log_every_t=None): - + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): if not log_every_t: log_every_t = self.log_every_t device = self.betas.device @@ -1202,24 +1268,27 @@ def p_sample_loop(self, if start_T is not None: timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( - range(0, timesteps)) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) if mask is not None: assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' + assert self.model.conditioning_key != "hybrid" tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised) if mask is not None: img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img + img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) @@ -1233,37 +1302,43 @@ def p_sample_loop(self, return img @torch.no_grad() - def sample(self, - cond, - batch_size=16, - return_intermediates=False, - x_T=None, - verbose=True, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - shape=None, - **kwargs): + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] if not isinstance(cond[key], list) else list( - map(lambda x: x[:batch_size], cond[key])) for key in cond + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond } else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - return self.p_sample_loop(cond, - shape, - return_intermediates=return_intermediates, - x_T=x_T, - verbose=verbose, - timesteps=timesteps, - quantize_denoised=quantize_denoised, - mask=mask, - x0=x0) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + ) @torch.no_grad() def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): @@ -1295,41 +1370,45 @@ def get_unconditional_conditioning(self, batch_size, null_label=None): return self.get_learned_conditioning(xc) else: raise NotImplementedError("todo") - if isinstance(c, list): # in case the encoder gives us a list + if isinstance(c, list): # in case the encoder gives us a list for i in range(len(c)): - c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + c[i] = repeat(c[i], "1 ... -> b ...", b=batch_size).to(self.device) else: - c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) return c @torch.no_grad() - def log_images(self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=50, - ddim_eta=0., - return_keys=None, - quantize_denoised=True, - inpaint=True, - plot_denoise_rows=False, - plot_progressive_rows=True, - plot_diffusion_rows=True, - unconditional_guidance_scale=1., - unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=50, + ddim_eta=0.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc = self.get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=N) + z, c, x, xrec, xc = self.get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, + ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x @@ -1341,10 +1420,10 @@ def log_images(self, elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', "cls"]: + elif self.cond_stage_key in ["class_label", "cls"]: try: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc except KeyError: # probably no "human_label" in batch pass @@ -1359,26 +1438,24 @@ def log_images(self, z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1386,16 +1463,16 @@ def log_images(self, denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid - if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( - self.first_stage_model, IdentityFirstStage): + if ( + quantize_denoised + and not isinstance(self.first_stage_model, AutoencoderKL) + and not isinstance(self.first_stage_model, IdentityFirstStage) + ): # also display when quantizing x0 while sampling with ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta, - quantize_denoised=True) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, quantize_denoised=True + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, # quantize_denoised=True) x_samples = self.decode_first_stage(samples.to(self.device)) @@ -1423,38 +1500,30 @@ def log_images(self, b, h, w = z.shape[0], z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in - mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 mask = mask[:, None, ...] with ema_scope("Plotting Inpaint"): - samples, _ = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - eta=ddim_eta, - ddim_steps=ddim_steps, - x0=z[:N], - mask=mask) + samples, _ = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask + ) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_inpainting"] = x_samples log["mask"] = mask # outpaint - mask = 1. - mask + mask = 1.0 - mask with ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - eta=ddim_eta, - ddim_steps=ddim_steps, - x0=z[:N], - mask=mask) + samples, _ = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask + ) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_outpainting"] = x_samples if plot_progressive_rows: with ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) + img, progressives = self.progressive_denoising( + c, shape=(self.channels, self.image_size, self.image_size), batch_size=N + ) prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") log["progressive_row"] = prog_row @@ -1472,10 +1541,11 @@ def configure_optimizers(self): rank_zero_info(f"{self.__class__.__name__}: Also optimizing conditioner params!") params = params + list(self.cond_stage_model.parameters()) if self.learn_logvar: - rank_zero_info('Diffusion model optimizing logvar') + rank_zero_info("Diffusion model optimizing logvar") params.append(self.logvar) from colossalai.nn.optimizer import HybridAdam + opt = HybridAdam(params, lr=lr) # opt = torch.optim.AdamW(params, lr=lr) @@ -1483,7 +1553,7 @@ def configure_optimizers(self): scheduler = LambdaLinearScheduler(**self.scheduler_config) rank_zero_info("Setting up LambdaLR scheduler...") - scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}] + scheduler = [{"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1}] return [opt], scheduler return opt @@ -1493,45 +1563,44 @@ def to_rgb(self, x): if not hasattr(self, "colorize"): self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) - x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x class DiffusionWrapper(pl.LightningModule): - def __init__(self, diff_model_config, conditioning_key): super().__init__() self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) self.diffusion_model = UNetModel(**diff_model_config) self.conditioning_key = conditioning_key - assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + assert self.conditioning_key in [None, "concat", "crossattn", "hybrid", "adm", "hybrid-adm", "crossattn-adm"] def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): if self.conditioning_key is None: out = self.diffusion_model(x, t) - elif self.conditioning_key == 'concat': + elif self.conditioning_key == "concat": xc = torch.cat([x] + c_concat, dim=1) out = self.diffusion_model(xc, t) - elif self.conditioning_key == 'crossattn': + elif self.conditioning_key == "crossattn": if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) else: cc = c_crossattn out = self.diffusion_model(x, t, context=cc) - elif self.conditioning_key == 'hybrid': + elif self.conditioning_key == "hybrid": xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc) - elif self.conditioning_key == 'hybrid-adm': + elif self.conditioning_key == "hybrid-adm": assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc, y=c_adm) - elif self.conditioning_key == 'crossattn-adm': + elif self.conditioning_key == "crossattn-adm": assert c_adm is not None cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(x, t, context=cc, y=c_adm) - elif self.conditioning_key == 'adm': + elif self.conditioning_key == "adm": cc = c_crossattn[0] out = self.diffusion_model(x, t, y=cc) else: @@ -1541,7 +1610,6 @@ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=N class LatentUpscaleDiffusion(LatentDiffusion): - def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): super().__init__(*args, **kwargs) # assumes that neither the cond_stage nor the low_scale_model contain trainable params @@ -1562,14 +1630,16 @@ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): if not log_mode: z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) else: - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) x_low = batch[self.low_scale_key][:bs] - x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = rearrange(x_low, "b h w c -> b c h w") if self.use_fp16: x_low = x_low.to(memory_format=torch.contiguous_format).half() else: @@ -1577,7 +1647,7 @@ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): zx, noise_level = self.low_scale_model(x_low) if self.noise_level_key is not None: # get noise level from batch instead, e.g. when extracting a custom noise level for bsr - raise NotImplementedError('TODO') + raise NotImplementedError("TODO") all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} if log_mode: @@ -1587,29 +1657,30 @@ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): return z, all_conds @torch.no_grad() - def log_images(self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=200, - ddim_eta=1., - return_keys=None, - plot_denoise_rows=False, - plot_progressive_rows=True, - plot_diffusion_rows=True, - unconditional_guidance_scale=1., - unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None log = dict() - z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, - self.first_stage_key, - bs=N, - log_mode=True) + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input( + batch, self.first_stage_key, bs=N, log_mode=True + ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x @@ -1623,9 +1694,9 @@ def log_images(self, elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', 'cls']: + elif self.cond_stage_key in ["class_label", "cls"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): @@ -1637,26 +1708,24 @@ def log_images(self, z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond=c, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1673,9 +1742,9 @@ def log_images(self, if k == "c_crossattn": assert isinstance(c[k], list) and len(c[k]) == 1 uc[k] = [uc_tmp] - elif k == "c_adm": # todo: only run with text-based guidance? + elif k == "c_adm": # todo: only run with text-based guidance? assert isinstance(c[k], torch.Tensor) - #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + # uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level uc[k] = c[k] elif isinstance(c[k], list): uc[k] = [c[k][i] for i in range(len(c[k]))] @@ -1697,9 +1766,9 @@ def log_images(self, if plot_progressive_rows: with ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) + img, progressives = self.progressive_denoising( + c, shape=(self.channels, self.image_size, self.image_size), batch_size=N + ) prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") log["progressive_row"] = prog_row @@ -1708,21 +1777,24 @@ def log_images(self, class LatentFinetuneDiffusion(LatentDiffusion): """ - Basis for different finetunas, such as inpainting or depth2image - To disable finetuning mode, set finetune_keys to None + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None """ def __init__( - self, - concat_keys: tuple, - finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", - "model_ema.diffusion_modelinput_blocks00weight"), - keep_finetune_dims=4, - # if model was trained without concat mode before and we would like to keep these channels - c_concat_log_start=None, # to log reconstruction of c_concat codes - c_concat_log_end=None, - *args, - **kwargs): + self, + concat_keys: tuple, + finetune_keys=( + "model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight", + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, + **kwargs, + ): ckpt = kwargs.pop("ckpt", None) ignore_keys = kwargs.pop("ignore_keys", list()) super().__init__(*args, **kwargs) @@ -1732,7 +1804,7 @@ def __init__( self.c_concat_log_start = c_concat_log_start self.c_concat_log_end = c_concat_log_end if exists(self.finetune_keys): - assert exists(ckpt), 'can only finetune from a given checkpoint' + assert exists(ckpt), "can only finetune from a given checkpoint" if exists(ckpt): self.init_from_ckpt(ckpt, ignore_keys) @@ -1755,13 +1827,14 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): rank_zero_info( f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only" ) - new_entry = torch.zeros_like(param) # zero init - assert exists(new_entry), 'did not find matching parameter to modify' - new_entry[:, :self.keep_dims, ...] = sd[k] + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), "did not find matching parameter to modify" + new_entry[:, : self.keep_dims, ...] = sd[k] sd[k] = new_entry - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: rank_zero_info(f"Missing Keys: {missing}") @@ -1769,23 +1842,25 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): rank_zero_info(f"Unexpected Keys: {unexpected}") @torch.no_grad() - def log_images(self, - batch, - N=8, - n_row=4, - sample=True, - ddim_steps=200, - ddim_eta=1., - return_keys=None, - quantize_denoised=True, - inpaint=True, - plot_denoise_rows=False, - plot_progressive_rows=True, - plot_diffusion_rows=True, - unconditional_guidance_scale=1., - unconditional_guidance_label=None, - use_ema_scope=True, - **kwargs): + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + unconditional_guidance_scale=1.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None @@ -1803,16 +1878,16 @@ def log_images(self, elif self.cond_stage_key in ["caption", "txt"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) log["conditioning"] = xc - elif self.cond_stage_key in ['class_label', 'cls']: + elif self.cond_stage_key in ["class_label", "cls"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) - log['conditioning'] = xc + log["conditioning"] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): log["original_conditioning"] = self.to_rgb(xc) if not (self.c_concat_log_start is None and self.c_concat_log_end is None): - log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start : self.c_concat_log_end]) if plot_diffusion_rows: # get diffusion row @@ -1820,29 +1895,28 @@ def log_images(self, z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w") + diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w") diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with ema_scope("Sampling"): - samples, z_denoise_row = self.sample_log(cond={ - "c_concat": [c_cat], - "c_crossattn": [c] - }, - batch_size=N, - ddim=use_ddim, - ddim_steps=ddim_steps, - eta=ddim_eta) + samples, z_denoise_row = self.sample_log( + cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples @@ -1856,10 +1930,7 @@ def log_images(self, uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} with ema_scope("Sampling with classifier-free guidance"): samples_cfg, _ = self.sample_log( - cond={ - "c_concat": [c_cat], - "c_crossattn": [c] - }, + cond={"c_concat": [c_cat], "c_crossattn": [c]}, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, @@ -1878,7 +1949,7 @@ class LatentInpaintDiffusion(LatentFinetuneDiffusion): can either run as pure inpainting model (only concat mode) or with mixed conditionings, e.g. mask as concat and text via cross-attn. To disable finetuning mode, set finetune_keys to None - """ + """ def __init__(self, concat_keys=("mask", "masked_image"), masked_image_key="masked_image", *args, **kwargs): super().__init__(concat_keys, *args, **kwargs) @@ -1888,21 +1959,23 @@ def __init__(self, concat_keys=("mask", "masked_image"), masked_image_key="maske @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for inpainting" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) c_cat = list() for ck in self.concat_keys: if self.use_fp16: - cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).half() + cc = rearrange(batch[ck], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).half() else: - cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + cc = rearrange(batch[ck], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).float() if bs is not None: cc = cc[:bs] cc = cc.to(self.device) @@ -1921,8 +1994,9 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs @torch.no_grad() def log_images(self, *args, **kwargs): log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) - log["masked_image"] = rearrange(args[0]["masked_image"], - 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + log["masked_image"] = ( + rearrange(args[0]["masked_image"], "b h w c -> b c h w").to(memory_format=torch.contiguous_format).float() + ) return log @@ -1939,13 +2013,15 @@ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwarg @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for depth2img" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 @@ -1963,10 +2039,10 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs align_corners=False, ) - depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, - dim=[1, 2, 3], - keepdim=True) - cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax( + cc, dim=[1, 2, 3], keepdim=True + ) + cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0 c_cat.append(cc) c_cat = torch.cat(c_cat, dim=1) all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} @@ -1978,24 +2054,21 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs def log_images(self, *args, **kwargs): log = super().log_images(*args, **kwargs) depth = self.depth_model(args[0][self.depth_stage_key]) - depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ - torch.amax(depth, dim=[1, 2, 3], keepdim=True) - log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. + depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), torch.amax( + depth, dim=[1, 2, 3], keepdim=True + ) + log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0 return log class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): """ - condition on low-res image (and optionally on some spatial noise augmentation) + condition on low-res image (and optionally on some spatial noise augmentation) """ - def __init__(self, - concat_keys=("lr",), - reshuffle_patch_size=None, - low_scale_config=None, - low_scale_key=None, - *args, - **kwargs): + def __init__( + self, concat_keys=("lr",), reshuffle_patch_size=None, low_scale_config=None, low_scale_key=None, *args, **kwargs + ): super().__init__(concat_keys=concat_keys, *args, **kwargs) self.reshuffle_patch_size = reshuffle_patch_size self.low_scale_model = None @@ -2015,13 +2088,15 @@ def instantiate_low_stage(self, config): @torch.no_grad() def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): # note: restricted to non-trainable encoders currently - assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' - z, c, x, xrec, xc = super().get_input(batch, - self.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=bs) + assert not self.cond_stage_trainable, "trainable cond stages not yet supported for upscaling-ft" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) assert exists(self.concat_keys) assert len(self.concat_keys) == 1 @@ -2030,13 +2105,15 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs noise_level = None for ck in self.concat_keys: cc = batch[ck] - cc = rearrange(cc, 'b h w c -> b c h w') + cc = rearrange(cc, "b h w c -> b c h w") if exists(self.reshuffle_patch_size): assert isinstance(self.reshuffle_patch_size, int) - cc = rearrange(cc, - 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', - p1=self.reshuffle_patch_size, - p2=self.reshuffle_patch_size) + cc = rearrange( + cc, + "b c (p1 h) (p2 w) -> b (p1 p2 c) h w", + p1=self.reshuffle_patch_size, + p2=self.reshuffle_patch_size, + ) if bs is not None: cc = cc[:bs] cc = cc.to(self.device) @@ -2055,5 +2132,5 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs @torch.no_grad() def log_images(self, *args, **kwargs): log = super().log_images(*args, **kwargs) - log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w") return log diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py index 7427f38c0753..f56611cb5fb3 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py @@ -1 +1 @@ -from .sampler import DPMSolverSampler \ No newline at end of file +from .sampler import DPMSolverSampler diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py index 095e5ba3ce0b..66063320ec78 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -1,17 +1,17 @@ -import torch -import torch.nn.functional as F import math + +import torch from tqdm import tqdm class NoiseScheduleVP: def __init__( - self, - schedule='discrete', - betas=None, - alphas_cumprod=None, - continuous_beta_0=0.1, - continuous_beta_1=20., + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, ): """Create a wrapper class for the forward SDE (VP type). *** @@ -70,50 +70,63 @@ def __init__( >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) """ - if schedule not in ['discrete', 'linear', 'cosine']: + if schedule not in ["discrete", "linear", "cosine"]: raise ValueError( "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule)) + schedule + ) + ) self.schedule = schedule - if schedule == 'discrete': + if schedule == "discrete": if betas is not None: log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) else: assert alphas_cumprod is not None log_alphas = 0.5 * torch.log(alphas_cumprod) self.total_N = len(log_alphas) - self.T = 1. - self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) + self.T = 1.0 + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape( + ( + 1, + -1, + ) + ) else: self.total_N = 1000 self.beta_0 = continuous_beta_0 self.beta_1 = continuous_beta_1 self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.cosine_beta_max = 999.0 + self.cosine_t_max = ( + math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) self.schedule = schedule - if schedule == 'cosine': + if schedule == "cosine": # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. self.T = 0.9946 else: - self.T = 1. + self.T = 1.0 def marginal_log_mean_coeff(self, t): """ Compute log(alpha_t) of a given continuous-time label t in [0, T]. """ - if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), - self.log_alpha_array.to(t.device)).reshape((-1)) - elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device) + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == "cosine": + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 return log_alpha_t @@ -127,48 +140,56 @@ def marginal_std(self, t): """ Compute sigma_t of a given continuous-time label t in [0, T]. """ - return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) def marginal_lambda(self, t): """ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. """ log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) return log_mean_coeff - log_std def inverse_lambda(self, lamb): """ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. """ - if self.schedule == 'linear': - tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0 ** 2 + tmp + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == 'discrete': - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1])) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) return t.reshape((-1,)) else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s + log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + t_fn = ( + lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) t = t_fn(log_alpha) return t def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, ): """Create a wrapper function for the noise prediction model. DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to @@ -249,8 +270,8 @@ def get_model_input_time(t_continuous): For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. For continuous-time DPMs, we just use `t_continuous`. """ - if noise_schedule.schedule == 'discrete': - return (t_continuous - 1. / noise_schedule.total_N) * 1000. + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 else: return t_continuous @@ -302,7 +323,7 @@ def model_fn(x, t_continuous): noise = noise_pred_fn(x, t_continuous) return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad elif guidance_type == "classifier-free": - if guidance_scale == 1. or unconditional_condition is None: + if guidance_scale == 1.0 or unconditional_condition is None: return noise_pred_fn(x, t_continuous, cond=condition) else: x_in = torch.cat([x] * 2) @@ -317,7 +338,7 @@ def model_fn(x, t_continuous): class DPM_Solver: - def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.0): """Construct a DPM-Solver. We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). @@ -387,20 +408,21 @@ def get_time_steps(self, skip_type, t_T, t_0, N, device): Returns: A pytorch tensor of the time steps, with the shape (N + 1,). """ - if skip_type == 'logSNR': + if skip_type == "logSNR": lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == 'time_uniform': + elif skip_type == "time_uniform": return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == 'time_quadratic': + elif skip_type == "time_quadratic": t_order = 2 - t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) return t else: raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) + ) def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): """ @@ -435,29 +457,57 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type if order == 3: K = steps // 3 + 1 if steps % 3 == 0: - orders = [3, ] * (K - 2) + [2, 1] + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] elif steps % 3 == 1: - orders = [3, ] * (K - 1) + [1] + orders = [ + 3, + ] * ( + K - 1 + ) + [1] else: - orders = [3, ] * (K - 1) + [2] + orders = [ + 3, + ] * ( + K - 1 + ) + [2] elif order == 2: if steps % 2 == 0: K = steps // 2 - orders = [2, ] * K + orders = [ + 2, + ] * K else: K = steps // 2 + 1 - orders = [2, ] * (K - 1) + [1] + orders = [ + 2, + ] * ( + K - 1 + ) + [1] elif order == 1: K = 1 - orders = [1, ] * steps + orders = [ + 1, + ] * steps else: raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == 'logSNR': + if skip_type == "logSNR": # To reproduce the results in DPM-Solver paper timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) else: timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ) + ).to(device) + ] return timesteps_outer, orders def denoise_to_zero_fn(self, x, s): @@ -491,12 +541,9 @@ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=Fal phi_1 = torch.expm1(-h) if model_s is None: model_s = self.model_fn(x, s) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - ) + x_t = expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s if return_intermediate: - return x_t, {'model_s': model_s} + return x_t, {"model_s": model_s} else: return x_t else: @@ -504,16 +551,17 @@ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=Fal if model_s is None: model_s = self.model_fn(x, s) x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s ) if return_intermediate: - return x_t, {'model_s': model_s} + return x_t, {"model_s": model_s} else: return x_t - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, - solver_type='dpm_solver'): + def singlestep_dpm_solver_second_update( + self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpm_solver" + ): """ Singlestep solver DPM-Solver-2 from time `s` to time `t`. Args: @@ -529,7 +577,7 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) if r1 is None: r1 = 0.5 @@ -539,8 +587,11 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret h = lambda_t - lambda_s lambda_s1 = lambda_s + r1 * h s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( - s1), ns.marginal_log_mean_coeff(t) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) @@ -550,23 +601,19 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret if model_s is None: model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) + x_s1 = expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( - model_s1 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1.0 / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * (model_s1 - model_s) ) else: phi_11 = torch.expm1(r1 * h) @@ -575,29 +622,39 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret if model_s is None: model_s = self.model_fn(x, s) x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s ) model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1.0 / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * (model_s1 - model_s) ) if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1} + return x_t, {"model_s": model_s, "model_s1": model_s1} else: return x_t - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, - return_intermediate=False, solver_type='dpm_solver'): + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpm_solver", + ): """ Singlestep solver DPM-Solver-3 from time `s` to time `t`. Args: @@ -616,12 +673,12 @@ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., mo Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) if r1 is None: - r1 = 1. / 3. + r1 = 1.0 / 3.0 if r2 is None: - r2 = 2. / 3. + r2 = 2.0 / 3.0 ns = self.noise_schedule dims = x.dim() lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) @@ -630,93 +687,98 @@ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., mo lambda_s2 = lambda_s + r2 * h s1 = ns.inverse_lambda(lambda_s1) s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( - s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( - s2), ns.marginal_std(t) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) if self.predict_x0: phi_11 = torch.expm1(-r1 * h) phi_12 = torch.expm1(-r2 * h) phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. - phi_2 = phi_1 / h + 1. + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 phi_3 = phi_2 / h - 0.5 if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) + x_s1 = expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s model_s1 = self.model_fn(x_s1, s1) x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) ) model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1.0 / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 ) else: phi_11 = torch.expm1(r1 * h) phi_12 = torch.expm1(r2 * h) phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. - phi_2 = phi_1 / h - 1. + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 phi_3 = phi_2 / h - 0.5 if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s ) model_s1 = self.model_fn(x_s1, s1) x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) ) model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1.0 / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 ) if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} else: return x_t @@ -733,14 +795,17 @@ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, Returns: x_t: A pytorch tensor. The approximated solution at time `t`. """ - if solver_type not in ['dpm_solver', 'taylor']: + if solver_type not in ["dpm_solver", "taylor"]: raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) ns = self.noise_schedule dims = x.dim() model_prev_1, model_prev_0 = model_prev_list t_prev_1, t_prev_0 = t_prev_list - lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( - t_prev_0), ns.marginal_lambda(t) + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -748,36 +813,36 @@ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, h_0 = lambda_prev_0 - lambda_prev_1 h = lambda_t - lambda_prev_0 r0 = h_0 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1) if self.predict_x0: - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0 ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1_0 ) else: - if solver_type == 'dpm_solver': + if solver_type == "dpm_solver": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0 ) - elif solver_type == 'taylor': + elif solver_type == "taylor": x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1_0 ) return x_t - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): """ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. Args: @@ -794,8 +859,12 @@ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, dims = x.dim() model_prev_2, model_prev_1, model_prev_0 = model_prev_list t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( - t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -804,28 +873,29 @@ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, h_0 = lambda_prev_0 - lambda_prev_1 h = lambda_t - lambda_prev_0 r0, r1 = h_0 / h, h_1 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2) D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) - D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1) if self.predict_x0: x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims) * D2 ) else: x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims) * D2 ) return x_t - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, - r2=None): + def singlestep_dpm_solver_update( + self, x, s, t, order, return_intermediate=False, solver_type="dpm_solver", r1=None, r2=None + ): """ Singlestep DPM-Solver with the order `order` from time `s` to time `t`. Args: @@ -844,15 +914,17 @@ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False if order == 1: return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1) + return self.singlestep_dpm_solver_second_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1 + ) elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1, r2=r2) + return self.singlestep_dpm_solver_third_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2 + ) else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"): """ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. Args: @@ -875,8 +947,9 @@ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, - solver_type='dpm_solver'): + def dpm_solver_adaptive( + self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpm_solver" + ): """ The adaptive step size solver based on singlestep DPM-Solver. Args: @@ -906,17 +979,17 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol if order == 2: r1 = 0.5 lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - solver_type=solver_type, - **kwargs) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, solver_type=solver_type, **kwargs + ) elif order == 3: - r1, r2 = 1. / 3., 2. / 3. - lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - return_intermediate=True, - solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, - solver_type=solver_type, - **kwargs) + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( + x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs + ) else: raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) while torch.abs((s - t_0)).mean() > t_err: @@ -926,20 +999,31 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.): + if torch.all(E <= 1.0): x = x_higher s = t x_prev = x_lower lambda_s = ns.marginal_lambda(s) - h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s) nfe += order - print('adaptive solver nfe', nfe) + print("adaptive solver nfe", nfe) return x - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, - ): + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=3, + skip_type="time_uniform", + method="singlestep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpm_solver", + atol=0.0078, + rtol=0.05, + ): """ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. ===================================================== @@ -1034,14 +1118,15 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time Returns: x_end: A pytorch tensor. The approximated solution at time `t_end`. """ - t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end t_T = self.noise_schedule.T if t_start is None else t_start device = x.device - if method == 'adaptive': + if method == "adaptive": with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, - solver_type=solver_type) - elif method == 'multistep': + x = self.dpm_solver_adaptive( + x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type + ) + elif method == "multistep": assert steps >= order timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) assert timesteps.shape[0] - 1 == steps @@ -1052,8 +1137,9 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # Init the first `order` values by lower order multistep DPM-Solver. for init_order in tqdm(range(1, order), desc="DPM init order"): vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, - solver_type=solver_type) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type + ) model_prev_list.append(self.model_fn(x, vec_t)) t_prev_list.append(vec_t) # Compute the remaining values by `order`-th order multistep DPM-Solver. @@ -1063,8 +1149,9 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time step_order = min(order, steps + 1 - step) else: step_order = order - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, - solver_type=solver_type) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type + ) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1] @@ -1072,20 +1159,22 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # We do not need to evaluate the final model value. if step < steps: model_prev_list[-1] = self.model_fn(x, vec_t) - elif method in ['singlestep', 'singlestep_fixed']: - if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, - skip_type=skip_type, - t_T=t_T, t_0=t_0, - device=device) - elif method == 'singlestep_fixed': + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver( + steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device + ) + elif method == "singlestep_fixed": K = steps // order - orders = [order, ] * K + orders = [ + order, + ] * K timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) for i, order in enumerate(orders): t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), - N=order, device=device) + timesteps_inner = self.get_time_steps( + skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device + ) lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) h = lambda_inner[-1] - lambda_inner[0] @@ -1101,6 +1190,7 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # other utility functions ############################################################# + def interpolate_fn(x, xp, yp): """ A piecewise linear function y = f(x), using xp and yp as keypoints. @@ -1122,7 +1212,9 @@ def interpolate_fn(x, xp, yp): torch.eq(x_idx, 0), torch.tensor(1, device=x.device), torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, ), ) end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) @@ -1132,7 +1224,9 @@ def interpolate_fn(x, xp, yp): torch.eq(x_idx, 0), torch.tensor(0, device=x.device), torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, ), ) y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) @@ -1151,4 +1245,4 @@ def expand_dims(v, dims): Returns: a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. """ - return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file + return v[(...,) + (None,) * (dims - 1)] diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8cf367..55dac8555e5f 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py @@ -1,13 +1,9 @@ """SAMPLING ONLY.""" import torch -from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver +from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper - -MODEL_TYPES = { - "eps": "noise", - "v": "v" -} +MODEL_TYPES = {"eps": "noise", "v": "v"} class DPMSolverSampler(object): @@ -15,7 +11,7 @@ def __init__(self, model, **kwargs): super().__init__() self.model = model to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) - self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod)) def register_buffer(self, name, attr): if type(attr) == torch.Tensor: @@ -24,30 +20,31 @@ def register_buffer(self, name, attr): setattr(self, name, attr) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] @@ -61,7 +58,7 @@ def sample(self, C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + print(f"Data shape for DPM-Solver sampling is {size}, sampling steps {S}") device = self.model.betas.device if x_T is None: @@ -69,7 +66,7 @@ def sample(self, else: img = x_T - ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod) model_fn = model_wrapper( lambda x, t, c: self.model.apply_model(x, t, c), @@ -82,6 +79,8 @@ def sample(self, ) dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + x = dpm_solver.sample( + img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True + ) - return x.to(device), None \ No newline at end of file + return x.to(device), None diff --git a/examples/images/diffusion/ldm/models/diffusion/plms.py b/examples/images/diffusion/ldm/models/diffusion/plms.py index 7002a365d271..b2b3f032e491 100644 --- a/examples/images/diffusion/ldm/models/diffusion/plms.py +++ b/examples/images/diffusion/ldm/models/diffusion/plms.py @@ -1,12 +1,10 @@ """SAMPLING ONLY.""" -import torch import numpy as np -from tqdm import tqdm -from functools import partial - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +import torch from ldm.models.diffusion.sampling_util import norm_thresholding +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from tqdm import tqdm class PLMSSampler(object): @@ -22,65 +20,72 @@ def register_buffer(self, name, attr): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True): if ddim_eta != 0: - raise ValueError('ddim_eta must be 0 for PLMS') - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + raise ValueError("ddim_eta must be 0 for PLMS") + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1))) # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs, + ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] @@ -94,34 +99,51 @@ def sample(self, # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') - - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) + print(f"Data shape for PLMS sampling is {size}") + + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) return samples, intermediates @torch.no_grad() - def plms_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): + def plms_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + ): device = self.model.betas.device b = shape[0] if x_T is None: @@ -135,12 +157,12 @@ def plms_sampling(self, cond, shape, subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running PLMS Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) old_eps = [] for i, step in enumerate(iterator): @@ -151,38 +173,64 @@ def plms_sampling(self, cond, shape, if mask is not None: assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img - - outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next, - dynamic_threshold=dynamic_threshold) + img = img_orig * mask + (1.0 - mask) * img + + outs = self.p_sample_plms( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + dynamic_threshold=dynamic_threshold, + ) img, pred_x0, e_t = outs old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) return img, intermediates @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, - dynamic_threshold=None): + def p_sample_plms( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + dynamic_threshold=None, + ): b, *_, device = *x.shape, x.device def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: e_t = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) @@ -199,7 +247,9 @@ def get_model_output(x, t): alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + ) sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas def get_x_prev_and_pred_x0(e_t, index): @@ -207,7 +257,7 @@ def get_x_prev_and_pred_x0(e_t, index): a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) # current prediction for x_0 pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() @@ -216,9 +266,9 @@ def get_x_prev_and_pred_x0(e_t, index): if dynamic_threshold is not None: pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: + if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 diff --git a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py index 7eff02be6d7c..a4681368112b 100644 --- a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py +++ b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py @@ -1,13 +1,9 @@ -import torch -import numpy as np - - def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions. From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") return x[(...,) + (None,) * dims_to_append] @@ -19,4 +15,4 @@ def norm_thresholding(x0, value): def spatial_norm_thresholding(x0, value): # b c h w s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) - return x0 * (value / s) \ No newline at end of file + return x0 * (value / s) diff --git a/examples/images/diffusion/ldm/modules/attention.py b/examples/images/diffusion/ldm/modules/attention.py index d504d939f6a0..f3c385e5138f 100644 --- a/examples/images/diffusion/ldm/modules/attention.py +++ b/examples/images/diffusion/ldm/modules/attention.py @@ -1,17 +1,17 @@ -from inspect import isfunction import math +from inspect import isfunction +from typing import Any, Optional + import torch import torch.nn.functional as F -from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional, Any - from ldm.modules.diffusionmodules.util import checkpoint - +from torch import einsum, nn try: import xformers import xformers.ops + XFORMERS_IS_AVAILBLE = True except: XFORMERS_IS_AVAILBLE = False @@ -22,7 +22,7 @@ def exists(val): def uniq(arr): - return{el: True for el in arr}.keys() + return {el: True for el in arr}.keys() def default(val, d): @@ -54,20 +54,13 @@ def forward(self, x): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -92,26 +85,10 @@ def __init__(self, in_channels): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -121,41 +98,38 @@ def forward(self, x): v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) - w_ = w_ * (int(c)**(-0.5)) + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) h_ = self.proj_out(h_) - return x+h_ + return x + h_ class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) - ) + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def forward(self, x, context=None, mask=None): h = self.heads @@ -165,22 +139,22 @@ def forward(self, x, context=None, mask=None): k = self.to_k(context) v = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale del q, k if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') + mask = rearrange(mask, "b ... -> b (...)") max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) + mask = repeat(mask, "b j -> (b h) () j", h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', sim, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = einsum("b i j, b j d -> b i d", sim, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return self.to_out(out) @@ -188,8 +162,10 @@ class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() - print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " - f"{heads} heads.") + print( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads." + ) inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -236,20 +212,36 @@ def forward(self, x, context=None, mask=None): class BasicTransformerBlock(nn.Module): ATTENTION_MODES = { "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention + "softmax-xformers": MemoryEfficientCrossAttention, } - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False): + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + ): super().__init__() attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" assert attn_mode in self.ATTENTION_MODES attn_cls = self.ATTENTION_MODES[attn_mode] self.disable_self_attn = disable_self_attn - self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.attn2 = attn_cls( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) @@ -274,10 +266,19 @@ class SpatialTransformer(nn.Module): Finally, reshape to image NEW: use_linear for more efficiency instead of the 1x1 convs """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None, - disable_self_attn=False, use_linear=False, - use_checkpoint=True): + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + ): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] @@ -285,25 +286,26 @@ def __init__(self, in_channels, n_heads, d_head, inner_dim = n_heads * d_head self.norm = Normalize(in_channels) if not use_linear: - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) else: self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) - for d in range(depth)] + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + ) + for d in range(depth) + ] ) if not use_linear: - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) else: self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.use_linear = use_linear @@ -317,15 +319,14 @@ def forward(self, x, context=None): x = self.norm(x) if not self.use_linear: x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + x = rearrange(x, "b c h w -> b (h w) c").contiguous() if self.use_linear: x = self.proj_in(x) for i, block in enumerate(self.transformer_blocks): x = block(x, context=context[i]) if self.use_linear: x = self.proj_out(x) - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() if not self.use_linear: x = self.proj_out(x) return x + x_in - diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py index fb088db58919..7ed8d98a44ad 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py @@ -17,6 +17,7 @@ try: import xformers import xformers.ops + XFORMERS_IS_AVAILBLE = True except: XFORMERS_IS_AVAILBLE = False @@ -39,7 +40,7 @@ def get_timestep_embedding(timesteps, embedding_dim): emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad + if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb @@ -54,7 +55,6 @@ def Normalize(in_channels, num_groups=32): class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv @@ -69,7 +69,6 @@ def forward(self, x): class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv @@ -88,7 +87,6 @@ def forward(self, x): class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() self.in_channels = in_channels @@ -133,7 +131,6 @@ def forward(self, x, temb): class AttnBlock(nn.Module): - def __init__(self, in_channels): super().__init__() self.in_channels = in_channels @@ -154,16 +151,16 @@ def forward(self, x): # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) @@ -173,9 +170,9 @@ def forward(self, x): class MemoryEfficientAttnBlock(nn.Module): """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation """ # @@ -199,34 +196,41 @@ def forward(self, x): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) q, k, v = map( - lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C). - contiguous(), + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), (q, k, v), ) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C) + out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) out = self.proj_out(out) return x + out class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): - def forward(self, x, context=None, mask=None): b, c, h, w = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') + x = rearrange(x, "b c h w -> b (h w) c") out = super().forward(x, context=context, mask=mask) - out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) return x + out def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", - "none"], f'attn_type {attn_type} unknown' + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": attn_type = "vanilla-xformers" if attn_type == "vanilla": @@ -245,21 +249,22 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): class Model(nn.Module): - - def __init__(self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type="vanilla"): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): super().__init__() if use_linear_attn: attn_type = "linear" @@ -274,10 +279,12 @@ def __init__(self, if self.use_timestep: # timestep embedding self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ]) + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) # downsampling self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) @@ -292,10 +299,10 @@ def __init__(self, block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -309,15 +316,13 @@ def __init__(self, # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() @@ -330,10 +335,13 @@ def __init__(self, if i_block == self.num_res_blocks: skip_in = ch * in_ch_mult[i_level] block.append( - ResnetBlock(in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -343,14 +351,14 @@ def __init__(self, if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, x, t=None, context=None): - #assert x.shape[2] == x.shape[3] == self.resolution + # assert x.shape[2] == x.shape[3] == self.resolution if context is not None: # assume aligned context, cat along channel axis x = torch.cat((x, context), dim=1) @@ -401,23 +409,24 @@ def get_last_layer(self): class Encoder(nn.Module): - - def __init__(self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - **ignore_kwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): super().__init__() if use_linear_attn: attn_type = "linear" @@ -442,10 +451,10 @@ def __init__(self, block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -459,23 +468,19 @@ def __init__(self, # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # end self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) def forward(self, x): # timestep embedding @@ -506,24 +511,25 @@ def forward(self, x): class Decoder(nn.Module): - - def __init__(self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type="vanilla", - **ignorekwargs): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): super().__init__() if use_linear_attn: attn_type = "linear" @@ -537,9 +543,9 @@ def __init__(self, self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) + (1,) + tuple(ch_mult) block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2**(self.num_resolutions - 1) + curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) @@ -548,15 +554,13 @@ def __init__(self, # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) # upsampling self.up = nn.ModuleList() @@ -566,10 +570,10 @@ def __init__(self, block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) @@ -579,14 +583,14 @@ def __init__(self, if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order + self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] + # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding @@ -622,17 +626,18 @@ def forward(self, z): class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() - self.model = nn.ModuleList([ - nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), - nn.Conv2d(2 * in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True) - ]) + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) # end self.norm_out = Normalize(in_channels) self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) @@ -651,7 +656,6 @@ def forward(self, x): class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0): super().__init__() # upsampling @@ -659,7 +663,7 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks block_in = in_channels - curr_res = resolution // 2**(self.num_resolutions - 1) + curr_res = resolution // 2 ** (self.num_resolutions - 1) self.res_blocks = nn.ModuleList() self.upsample_blocks = nn.ModuleList() for i_level in range(self.num_resolutions): @@ -667,10 +671,10 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): res_block.append( - ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: @@ -696,21 +700,24 @@ def forward(self, x): class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): super().__init__() # residual block, interpolate, residual block self.factor = factor self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1) - self.res_block1 = nn.ModuleList([ - ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) - for _ in range(depth) - ]) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ] + ) self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ - ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) - for _ in range(depth) - ]) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ] + ) self.conv_out = nn.Conv2d( mid_channels, @@ -722,9 +729,9 @@ def forward(self, x): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) - x = torch.nn.functional.interpolate(x, - size=(int(round(x.shape[2] * self.factor)), - int(round(x.shape[3] * self.factor)))) + x = torch.nn.functional.interpolate( + x, size=(int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor))) + ) x = self.attn(x) for block in self.res_block2: x = block(x, None) @@ -733,37 +740,42 @@ def forward(self, x): class MergedRescaleEncoder(nn.Module): - - def __init__(self, - in_channels, - ch, - resolution, - out_ch, - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - ch_mult=(1, 2, 4, 8), - rescale_factor=1.0, - rescale_module_depth=1): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): super().__init__() intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, - num_res_blocks=num_res_blocks, - ch=ch, - ch_mult=ch_mult, - z_channels=intermediate_chn, - double_z=False, - resolution=resolution, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, - in_channels=intermediate_chn, - mid_channels=intermediate_chn, - out_channels=out_ch, - depth=rescale_module_depth) + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) def forward(self, x): x = self.encoder(x) @@ -772,36 +784,41 @@ def forward(self, x): class MergedRescaleDecoder(nn.Module): - - def __init__(self, - z_channels, - out_ch, - resolution, - num_res_blocks, - attn_resolutions, - ch, - ch_mult=(1, 2, 4, 8), - dropout=0.0, - resamp_with_conv=True, - rescale_factor=1.0, - rescale_module_depth=1): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): super().__init__() tmp_chn = z_channels * ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, - z_channels=tmp_chn, - attn_resolutions=attn_resolutions, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - in_channels=None, - num_res_blocks=num_res_blocks, - ch_mult=ch_mult, - resolution=resolution, - ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, - in_channels=z_channels, - mid_channels=tmp_chn, - out_channels=tmp_chn, - depth=rescale_module_depth) + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) def forward(self, x): x = self.rescaler(x) @@ -810,27 +827,27 @@ def forward(self, x): class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): super().__init__() assert out_size >= in_size num_blocks = int(np.log2(out_size // in_size)) + 1 - factor_up = 1. + (out_size % in_size) + factor_up = 1.0 + (out_size % in_size) rank_zero_info( f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" ) - self.rescaler = LatentRescaler(factor=factor_up, - in_channels=in_channels, - mid_channels=2 * in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, - resolution=out_size, - z_channels=in_channels, - num_res_blocks=2, - attn_resolutions=[], - in_channels=None, - ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) + self.rescaler = LatentRescaler( + factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels, out_channels=in_channels + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) def forward(self, x): x = self.rescaler(x) @@ -839,14 +856,14 @@ def forward(self, x): class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): super().__init__() self.with_conv = learned self.mode = mode if self.with_conv: rank_zero_info( - f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py index cd639d936046..614fe510f20e 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -1,21 +1,20 @@ -from abc import abstractmethod import math +from abc import abstractmethod import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F - +from ldm.modules.attention import SpatialTransformer from ldm.modules.diffusionmodules.util import ( + avg_pool_nd, checkpoint, conv_nd, linear, - avg_pool_nd, - zero_module, normalization, timestep_embedding, + zero_module, ) -from ldm.modules.attention import SpatialTransformer from ldm.util import exists @@ -23,6 +22,7 @@ def convert_module_to_f16(x): pass + def convert_module_to_f32(x): pass @@ -41,7 +41,7 @@ def __init__( output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -108,25 +108,25 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") else: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: x = self.conv(x) return x + class TransposedUpsample(nn.Module): - 'Learned 2x upsampling without padding' + "Learned 2x upsampling without padding" + def __init__(self, channels, out_channels=None, ks=5): super().__init__() self.channels = channels self.out_channels = out_channels or channels - self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2) - def forward(self,x): + def forward(self, x): return self.up(x) @@ -139,7 +139,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -147,9 +147,7 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding - ) + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) @@ -225,17 +223,13 @@ def __init__( normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) @@ -246,10 +240,7 @@ def forward(self, x, emb): :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) def _forward(self, x, emb): if self.updown: @@ -311,8 +302,10 @@ def __init__( self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - #return pt_checkpoint(self._forward, x) # pytorch + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch def _forward(self, x): b, c, *spatial = x.shape @@ -339,7 +332,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial ** 2) * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += th.DoubleTensor([matmul_ops]) @@ -363,9 +356,7 @@ def forward(self, qkv): ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) @@ -460,10 +451,10 @@ def __init__( use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, disable_self_attentions=None, num_attention_blocks=None, @@ -472,11 +463,16 @@ def __init__( ): super().__init__() if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." if context_dim is not None: - assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: context_dim = list(context_dim) @@ -484,10 +480,10 @@ def __init__( num_heads_upsample = num_heads if num_heads == -1: - assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set" if num_head_channels == -1: - assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + assert num_heads != -1, "Either num_heads or num_head_channels has to be set" self.image_size = image_size self.in_channels = in_channels @@ -497,19 +493,25 @@ def __init__( self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: if len(num_res_blocks) != len(channel_mult): - raise ValueError("provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult") + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) self.num_res_blocks = num_res_blocks if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") + assert all( + map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." + ) self.attention_resolutions = attention_resolutions self.dropout = dropout @@ -540,11 +542,7 @@ def __init__( raise ValueError() self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] ) self._feature_size = model_channels input_block_chans = [model_channels] @@ -571,7 +569,7 @@ def __init__( num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] @@ -586,10 +584,17 @@ def __init__( num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -610,9 +615,7 @@ def __init__( down=True, ) if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ) ch = out_ch @@ -626,7 +629,7 @@ def __init__( num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels self.middle_block = TimestepEmbedSequential( ResBlock( @@ -643,11 +646,18 @@ def __init__( num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint - ), + ) + if not use_spatial_transformer + else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + ), ResBlock( ch, time_embed_dim, @@ -682,7 +692,7 @@ def __init__( num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: - #num_heads = 1 + # num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels if exists(disable_self_attentions): disabled_sa = disable_self_attentions[level] @@ -697,10 +707,17 @@ def __init__( num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, ) ) if level and i == self.num_res_blocks[level]: @@ -730,10 +747,10 @@ def __init__( ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) def convert_to_fp16(self): """ @@ -751,7 +768,7 @@ def convert_to_fp32(self): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py index 03816662098c..82cc2157ca68 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py @@ -1,8 +1,8 @@ -import torch -import torch.nn as nn -import numpy as np from functools import partial +import numpy as np +import torch +import torch.nn as nn from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule from ldm.util import default @@ -14,37 +14,41 @@ def __init__(self, noise_schedule_config=None): if noise_schedule_config is not None: self.register_schedule(**noise_schedule_config) - def register_schedule(self, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. - betas + def register_schedule( + self, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 + ): + betas = make_beta_schedule( + beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s + ) + alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - timesteps, = betas.shape + (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + assert alphas_cumprod.shape[0] == self.num_timesteps, "alphas have to be defined for each timestep" to_torch = partial(torch.tensor, dtype=torch.float32) - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) def forward(self, x): return x, None @@ -76,6 +80,3 @@ def forward(self, x, noise_level=None): assert isinstance(noise_level, torch.Tensor) z = self.q_sample(x, noise_level) return z, noise_level - - - diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py index 36b4a171b6c2..aed1b061323a 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/util.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/util.py @@ -8,7 +8,6 @@ # thanks! import math -import os import numpy as np import torch @@ -19,10 +18,10 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if schedule == "linear": - betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2) + betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 elif schedule == "cosine": - timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s) + timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] @@ -32,18 +31,18 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, elif schedule == "sqrt_linear": betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) elif schedule == "sqrt": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5 + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 else: raise ValueError(f"schedule '{schedule}' unknown.") return betas.numpy() def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): - if ddim_discr_method == 'uniform': + if ddim_discr_method == "uniform": c = num_ddpm_timesteps // num_ddim_timesteps ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == 'quad': - ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int) + elif ddim_discr_method == "quad": + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int) else: raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') @@ -51,7 +50,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep # add one to get the final alpha values right (the ones from first scale to data during sampling) steps_out = ddim_timesteps + 1 if verbose: - print(f'Selected timesteps for ddim sampler: {steps_out}') + print(f"Selected timesteps for ddim sampler: {steps_out}") return steps_out @@ -63,9 +62,11 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): # according the the formula provided in https://arxiv.org/abs/2010.02502 sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' - f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}") + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) return sigmas, alphas, alphas_prev @@ -106,6 +107,7 @@ def checkpoint(func, inputs, params, flag): """ if flag: from torch.utils.checkpoint import checkpoint as torch_checkpoint + return torch_checkpoint(func, *inputs) # args = tuple(inputs) + tuple(params) # return CheckpointFunction.apply(func, len(inputs), *args) @@ -114,7 +116,6 @@ def checkpoint(func, inputs, params, flag): class CheckpointFunction(torch.autograd.Function): - @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function @@ -123,7 +124,7 @@ def forward(ctx, run_function, length, *args): ctx.gpu_autocast_kwargs = { "enabled": torch.is_autocast_enabled(), "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled() + "cache_enabled": torch.is_autocast_cache_enabled(), } with torch.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) @@ -132,8 +133,7 @@ def forward(ctx, run_function, length, *args): @staticmethod def backward(ctx, *output_grads): ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(), \ - torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d # Tensors. @@ -162,14 +162,15 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ if not repeat_only: half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / - half).to(device=timesteps.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) else: - embedding = repeat(timesteps, 'b -> b d', d=dim) + embedding = repeat(timesteps, "b -> b d", d=dim) return embedding @@ -210,13 +211,11 @@ def normalization(channels): # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. class SiLU(nn.Module): - def forward(self, x): return x * torch.sigmoid(x) class GroupNorm32(nn.GroupNorm): - def forward(self, x): return super().forward(x.float()).type(x.dtype) @@ -255,7 +254,6 @@ def avg_pool_nd(dims, *args, **kwargs): class HybridConditioner(nn.Module): - def __init__(self, c_concat_config, c_crossattn_config): super().__init__() self.concat_conditioner = instantiate_from_config(c_concat_config) @@ -264,7 +262,7 @@ def __init__(self, c_concat_config, c_crossattn_config): def forward(self, c_concat, c_crossattn): c_concat = self.concat_conditioner(c_concat) c_crossattn = self.crossattn_conditioner(c_crossattn) - return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} def noise_like(shape, device, repeat=False): diff --git a/examples/images/diffusion/ldm/modules/distributions/distributions.py b/examples/images/diffusion/ldm/modules/distributions/distributions.py index f2b8ef901130..b5f3b1ad48da 100644 --- a/examples/images/diffusion/ldm/modules/distributions/distributions.py +++ b/examples/images/diffusion/ldm/modules/distributions/distributions.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch class AbstractDistribution: @@ -38,25 +38,25 @@ def sample(self): def kl(self, other=None): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) - - def nll(self, sample, dims=[1,2,3]): + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: - return torch.Tensor([0.]) + return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) def mode(self): return self.mean @@ -78,15 +78,8 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] + logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) ) diff --git a/examples/images/diffusion/ldm/modules/ema.py b/examples/images/diffusion/ldm/modules/ema.py index bded25019b9b..c3863269675e 100644 --- a/examples/images/diffusion/ldm/modules/ema.py +++ b/examples/images/diffusion/ldm/modules/ema.py @@ -6,17 +6,18 @@ class LitEma(nn.Module): def __init__(self, model, decay=0.9999, use_num_upates=True): super().__init__() if decay < 0.0 or decay > 1.0: - raise ValueError('Decay must be between 0 and 1') + raise ValueError("Decay must be between 0 and 1") self.m_name2s_name = {} - self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates - else torch.tensor(-1, dtype=torch.int)) + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int) + ) for name, p in model.named_parameters(): if p.requires_grad: # remove as '.'-character is not allowed in buffers - s_name = name.replace('.', '') + s_name = name.replace(".", "") self.m_name2s_name.update({name: s_name}) self.register_buffer(s_name, p.clone().detach().data) @@ -24,7 +25,7 @@ def __init__(self, model, decay=0.9999, use_num_upates=True): def reset_num_updates(self): del self.num_updates - self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) def forward(self, model): decay = self.decay diff --git a/examples/images/diffusion/ldm/modules/encoders/modules.py b/examples/images/diffusion/ldm/modules/encoders/modules.py index 4edd5496b9e6..58bff2382c47 100644 --- a/examples/images/diffusion/ldm/modules/encoders/modules.py +++ b/examples/images/diffusion/ldm/modules/encoders/modules.py @@ -1,11 +1,9 @@ +import open_clip import torch import torch.nn as nn +from ldm.util import count_params from torch.utils.checkpoint import checkpoint - -from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel - -import open_clip -from ldm.util import default, count_params +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer class AbstractEncoder(nn.Module): @@ -17,13 +15,12 @@ def encode(self, *args, **kwargs): class IdentityEncoder(AbstractEncoder): - def encode(self, x): return x class ClassEmbedder(nn.Module): - def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): super().__init__() self.key = key self.embedding = nn.Embedding(n_classes, embed_dim) @@ -35,9 +32,9 @@ def forward(self, batch, key=None, disable_dropout=False): key = self.key # this is for use in crossattn c = batch[key][:, None] - if self.ucg_rate > 0. and not disable_dropout: - mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) - c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + if self.ucg_rate > 0.0 and not disable_dropout: + mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) c = c.long() c = self.embedding(c) return c @@ -57,24 +54,34 @@ def disabled_train(self, mode=True): class FrozenT5Embedder(AbstractEncoder): """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + + def __init__( + self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device - self.max_length = max_length # TODO: typical value? + self.max_length = max_length # TODO: typical value? if freeze: self.freeze() def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) tokens = batch_encoding["input_ids"].to(self.device) outputs = self.transformer(input_ids=tokens) @@ -87,13 +94,18 @@ def encode(self, text): class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" - LAYERS = [ - "last", - "pooled", - "hidden" - ] - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, - freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + ): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS self.tokenizer = CLIPTokenizer.from_pretrained(version) @@ -110,15 +122,22 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_l def freeze(self): self.transformer = self.transformer.eval() - #self.train = disabled_train + # self.train = disabled_train for param in self.parameters(): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") if self.layer == "last": z = outputs.last_hidden_state elif self.layer == "pooled": @@ -135,16 +154,19 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): """ Uses the OpenCLIP transformer encoder for text """ + LAYERS = [ - #"pooled", + # "pooled", "last", - "penultimate" + "penultimate", ] - def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, - freeze=True, layer="last"): + + def __init__( + self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, freeze=True, layer="last" + ): super().__init__() assert layer in self.LAYERS - model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device("cpu"), pretrained=version) del model.visual self.model = model @@ -179,7 +201,7 @@ def encode_with_transformer(self, text): x = self.model.ln_final(x) return x - def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): for i, r in enumerate(self.model.transformer.resblocks): if i == len(self.model.transformer.resblocks) - self.layer_idx: break @@ -194,13 +216,21 @@ def encode(self, text): class FrozenCLIPT5Encoder(AbstractEncoder): - def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", - clip_max_length=77, t5_max_length=77): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): super().__init__() self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " - f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params." + ) def encode(self, text): return self(text) @@ -209,5 +239,3 @@ def forward(self, text): clip_z = self.clip_encoder.encode(text) t5_z = self.t5_encoder.encode(text) return [clip_z, t5_z] - - diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py index 32ef56169978..879b2aa099b6 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py @@ -10,33 +10,32 @@ # -------------------------------------------- """ -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +from functools import partial + +import albumentations +import cv2 +import ldm.modules.image_degradation.utils_image as util +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util def modcrop_np(img, sf): - ''' + """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image - ''' + """ w, h = img.shape[:2] im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] + return im[: w - w % sf, : h - h % sf, ...] """ @@ -54,7 +53,7 @@ def analytic_kernel(k): # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] @@ -63,7 +62,7 @@ def analytic_kernel(k): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel + """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range @@ -74,7 +73,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): k : kernel """ - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1.0, 0.0])) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) @@ -126,13 +125,13 @@ def shift_pixel(x, sf, upper_left=True): def blur(x, k): - ''' + """ x: image, NxcxHxW k: kernel, Nx1xhxw - ''' + """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate") k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) @@ -142,8 +141,8 @@ def blur(x, k): return x -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.0, noise_level=0): + """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var @@ -157,8 +156,7 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] @@ -208,13 +206,13 @@ def fspecial_laplacian(alpha): def fspecial(filter_type, *args, **kwargs): - ''' + """ python code from: https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': + """ + if filter_type == "gaussian": return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': + if filter_type == "laplacian": return fspecial_laplacian(*args, **kwargs) @@ -226,19 +224,19 @@ def fspecial(filter_type, *args, **kwargs): def bicubic_degradation(x, sf=3): - ''' + """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image - ''' + """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling + """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double @@ -253,14 +251,14 @@ def srmd_degradation(x, k, sf=3): pages={3262--3271}, year={2018} } - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + """ + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur + """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double @@ -275,22 +273,22 @@ def dpsr_degradation(x, k, sf=3): pages={1671--1681}, year={2019} } - ''' + """ x = bicubic_degradation(x, sf=sf) - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") return x def classical_degradation(x, k, sf=3): - ''' blur + downsampling + """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + """ + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] @@ -314,7 +312,7 @@ def add_sharpening(img, weight=0.5, radius=50, threshold=10): blur = cv2.GaussianBlur(img, (radius, radius), 0) residual = img - blur mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') + mask = mask.astype("float32") soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) K = img + weight * residual @@ -330,8 +328,8 @@ def add_blur(img, sf=4): l2 = wd2 * random.random() k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) else: - k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + k = fspecial("gaussian", 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode="mirror") return img @@ -366,6 +364,7 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() @@ -374,11 +373,11 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: # add noise - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -392,23 +391,23 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) @@ -418,7 +417,7 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(30, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -428,10 +427,10 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + hq = hq[rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :] return lq, hq @@ -452,18 +451,19 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") hq = img.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]) + ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) @@ -475,7 +475,6 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: img = add_blur(img, sf=sf) @@ -487,13 +486,16 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror") img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -541,18 +543,20 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): """ image = util.uint2single(image) isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = image.shape[:2] - hq = image.copy() + image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) @@ -564,7 +568,6 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: image = add_blur(image, sf=sf) @@ -576,13 +579,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror") image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -609,7 +615,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {"image":image} + example = {"image": image} return example @@ -630,11 +636,11 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc """ h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") if use_sharp: img = add_sharpening(img) @@ -686,11 +692,12 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc with torch.no_grad(): img, hq = isp_model.forward(img.copy(), hq) else: - print('check the shuffle!') + print("check the shuffle!") # resize to desired size - img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), interpolation=random.choice([1, 2, 3]) + ) # add final JPEG compression noise img = add_JPEG_noise(img) @@ -701,30 +708,30 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc return img, hq -if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - print(img) - img = util.uint2single(img) - print(img) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_lq = deg_fn(img) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') - - +if __name__ == "__main__": + print("hey") + img = util.imread_uint("utils/test.png", 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize( + util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + ".png") diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py index 808c7f882cb7..cf3f83f0c011 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +from functools import partial + +import albumentations +import cv2 +import ldm.modules.image_degradation.utils_image as util +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util """ # -------------------------------------------- @@ -25,17 +24,18 @@ # -------------------------------------------- """ + def modcrop_np(img, sf): - ''' + """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image - ''' + """ w, h = img.shape[:2] im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] + return im[: w - w % sf, : h - h % sf, ...] """ @@ -53,7 +53,7 @@ def analytic_kernel(k): # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] @@ -62,7 +62,7 @@ def analytic_kernel(k): def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel + """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range @@ -73,7 +73,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): k : kernel """ - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1.0, 0.0])) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) @@ -125,13 +125,13 @@ def shift_pixel(x, sf, upper_left=True): def blur(x, k): - ''' + """ x: image, NxcxHxW k: kernel, Nx1xhxw - ''' + """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode="replicate") k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) @@ -141,8 +141,8 @@ def blur(x, k): return x -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.0, noise_level=0): + """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var @@ -156,8 +156,7 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] @@ -207,13 +206,13 @@ def fspecial_laplacian(alpha): def fspecial(filter_type, *args, **kwargs): - ''' + """ python code from: https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': + """ + if filter_type == "gaussian": return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': + if filter_type == "laplacian": return fspecial_laplacian(*args, **kwargs) @@ -225,19 +224,19 @@ def fspecial(filter_type, *args, **kwargs): def bicubic_degradation(x, sf=3): - ''' + """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image - ''' + """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling + """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double @@ -252,14 +251,14 @@ def srmd_degradation(x, k, sf=3): pages={3262--3271}, year={2018} } - ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + """ + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur + """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double @@ -274,22 +273,22 @@ def dpsr_degradation(x, k, sf=3): pages={1671--1681}, year={2019} } - ''' + """ x = bicubic_degradation(x, sf=sf) - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") return x def classical_degradation(x, k, sf=3): - ''' blur + downsampling + """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image - ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + """ + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] @@ -313,7 +312,7 @@ def add_sharpening(img, weight=0.5, radius=50, threshold=10): blur = cv2.GaussianBlur(img, (radius, radius), 0) residual = img - blur mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') + mask = mask.astype("float32") soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) K = img + weight * residual @@ -325,16 +324,16 @@ def add_blur(img, sf=4): wd2 = 4.0 + sf wd = 2.0 + 0.2 * sf - wd2 = wd2/4 - wd = wd/4 + wd2 = wd2 / 4 + wd = wd / 4 if random.random() < 0.5: l1 = wd2 * random.random() l2 = wd2 * random.random() k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) else: - k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) - img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + k = fspecial("gaussian", random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode="mirror") return img @@ -369,6 +368,7 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() @@ -377,11 +377,11 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: # add noise - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -395,23 +395,23 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): elif rnum < 0.4: img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: - L = noise_level2 / 255. + L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. + img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) @@ -421,7 +421,7 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(80, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -431,10 +431,10 @@ def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + hq = hq[rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :] return lq, hq @@ -455,18 +455,19 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') + raise ValueError(f"img size ({h1}X{w1}) is too small!") hq = img.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]) + ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) @@ -478,7 +479,6 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: img = add_blur(img, sf=sf) @@ -490,13 +490,16 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) + img = cv2.resize( + img, + (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror") img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -544,18 +547,20 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): """ image = util.uint2single(image) isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] # mod crop h, w = image.shape[:2] - hq = image.copy() + image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) @@ -567,7 +572,6 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: - if i == 0: image = add_blur(image, sf=sf) @@ -582,13 +586,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): # downsample2 if random.random() < 0.8: sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) + image = cv2.resize( + image, + (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3]), + ) else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror") image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -617,16 +624,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): image = add_JPEG_noise(image) image = util.single2uint(image) if up: - image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then + image = cv2.resize( + image, (w1, h1), interpolation=cv2.INTER_CUBIC + ) # todo: random, as above? want to condition on it then example = {"image": image} return example - - -if __name__ == '__main__': +if __name__ == "__main__": print("hey") - img = util.imread_uint('utils/test.png', 3) + img = util.imread_uint("utils/test.png", 3) img = img[:448, :448] h = img.shape[0] // 4 print("resizing to", h) @@ -638,14 +645,17 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): img_lq = deg_fn(img)["image"] img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[ + "image" + ] print(img_lq.shape) print("bicubic", img_lq_bicubic.shape) print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), - (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) + lq_nearest = cv2.resize( + util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) + lq_bicubic_nearest = cv2.resize( + util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0 + ) img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') + util.imsave(img_concat, str(i) + ".png") diff --git a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py index 0175f155ad90..71fae1084b61 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py @@ -1,18 +1,20 @@ -import os import math +import os import random +from datetime import datetime + +import cv2 import numpy as np import torch -import cv2 from torchvision.utils import make_grid -from datetime import datetime -#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + +# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py -os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" -''' +""" # -------------------------------------------- # Kai Zhang (github: https://github.com/cszn) # 03/Mar/2019 @@ -20,10 +22,10 @@ # https://github.com/twhui/SRGAN-pyTorch # https://github.com/xinntao/BasicSR # -------------------------------------------- -''' +""" -IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] +IMG_EXTENSIONS = [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP", ".tif"] def is_image_file(filename): @@ -31,12 +33,12 @@ def is_image_file(filename): def get_timestamp(): - return datetime.now().strftime('%y%m%d-%H%M%S') + return datetime.now().strftime("%y%m%d-%H%M%S") def imshow(x, title=None, cbar=False, figsize=None): plt.figure(figsize=figsize) - plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + plt.imshow(np.squeeze(x), interpolation="nearest", cmap="gray") if title: plt.title(title) if cbar: @@ -44,24 +46,24 @@ def imshow(x, title=None, cbar=False, figsize=None): plt.show() -def surf(Z, cmap='rainbow', figsize=None): +def surf(Z, cmap="rainbow", figsize=None): plt.figure(figsize=figsize) - ax3 = plt.axes(projection='3d') + ax3 = plt.axes(projection="3d") w, h = Z.shape[:2] - xx = np.arange(0,w,1) - yy = np.arange(0,h,1) + xx = np.arange(0, w, 1) + yy = np.arange(0, h, 1) X, Y = np.meshgrid(xx, yy) - ax3.plot_surface(X,Y,Z,cmap=cmap) - #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + ax3.plot_surface(X, Y, Z, cmap=cmap) + # ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) plt.show() -''' +""" # -------------------------------------------- # get image pathes # -------------------------------------------- -''' +""" def get_image_paths(dataroot): @@ -72,37 +74,37 @@ def get_image_paths(dataroot): def _get_paths_from_images(path): - assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + assert os.path.isdir(path), "{:s} is not a valid directory".format(path) images = [] for dirpath, _, fnames in sorted(os.walk(path)): for fname in sorted(fnames): if is_image_file(fname): img_path = os.path.join(dirpath, fname) images.append(img_path) - assert images, '{:s} has no valid image file'.format(path) + assert images, "{:s} has no valid image file".format(path) return images -''' +""" # -------------------------------------------- -# split large images into small images +# split large images into small images # -------------------------------------------- -''' +""" def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): w, h = img.shape[:2] patches = [] if w > p_max and h > p_max: - w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) - h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) - w1.append(w-p_size) - h1.append(h-p_size) -# print(w1) -# print(h1) + w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int)) + w1.append(w - p_size) + h1.append(h - p_size) + # print(w1) + # print(h1) for i in w1: for j in h1: - patches.append(img[i:i+p_size, j:j+p_size,:]) + patches.append(img[i : i + p_size, j : j + p_size, :]) else: patches.append(img) @@ -118,7 +120,7 @@ def imssave(imgs, img_path): for i, img in enumerate(imgs): if img.ndim == 3: img = img[:, :, [2, 1, 0]] - new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + new_path = os.path.join(os.path.dirname(img_path), img_name + str("_s{:04d}".format(i)) + ".png") cv2.imwrite(new_path, img) @@ -139,15 +141,16 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, # img_name, ext = os.path.splitext(os.path.basename(img_path)) img = imread_uint(img_path, n_channels=n_channels) patches = patches_from_image(img, p_size, p_overlap, p_max) - imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) - #if original_dataroot == taget_dataroot: - #del img_path + imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path))) + # if original_dataroot == taget_dataroot: + # del img_path + -''' +""" # -------------------------------------------- # makedir # -------------------------------------------- -''' +""" def mkdir(path): @@ -165,18 +168,18 @@ def mkdirs(paths): def mkdir_and_rename(path): if os.path.exists(path): - new_name = path + '_archived_' + get_timestamp() - print('Path already exists. Rename it to [{:s}]'.format(new_name)) + new_name = path + "_archived_" + get_timestamp() + print("Path already exists. Rename it to [{:s}]".format(new_name)) os.rename(path, new_name) os.makedirs(path) -''' +""" # -------------------------------------------- # read image from path # opencv is fast, but read BGR numpy image # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -206,6 +209,7 @@ def imsave(img, img_path): img = img[:, :, [2, 1, 0]] cv2.imwrite(img_path, img) + def imwrite(img, img_path): img = np.squeeze(img) if img.ndim == 3: @@ -213,7 +217,6 @@ def imwrite(img, img_path): cv2.imwrite(img_path, img) - # -------------------------------------------- # get single image of size HxWxn_channles (BGR) # -------------------------------------------- @@ -221,7 +224,7 @@ def read_img(path): # read image by cv2 # return: Numpy float32, HWC, BGR, [0,1] img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE - img = img.astype(np.float32) / 255. + img = img.astype(np.float32) / 255.0 if img.ndim == 2: img = np.expand_dims(img, axis=2) # some images have 4 channels @@ -230,7 +233,7 @@ def read_img(path): return img -''' +""" # -------------------------------------------- # image format conversion # -------------------------------------------- @@ -238,7 +241,7 @@ def read_img(path): # numpy(single) <---> tensor # numpy(unit) <---> tensor # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -247,23 +250,19 @@ def read_img(path): def uint2single(img): - - return np.float32(img/255.) + return np.float32(img / 255.0) def single2uint(img): - - return np.uint8((img.clip(0, 1)*255.).round()) + return np.uint8((img.clip(0, 1) * 255.0).round()) def uint162single(img): - - return np.float32(img/65535.) + return np.float32(img / 65535.0) def single2uint16(img): - - return np.uint16((img.clip(0, 1)*65535.).round()) + return np.uint16((img.clip(0, 1) * 65535.0).round()) # -------------------------------------------- @@ -275,14 +274,14 @@ def single2uint16(img): def uint2tensor4(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0) # convert uint to 3-dimensional torch tensor def uint2tensor3(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0) # convert 2/3/4-dimensional torch tensor to uint @@ -290,7 +289,7 @@ def tensor2uint(img): img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() if img.ndim == 3: img = np.transpose(img, (1, 2, 0)) - return np.uint8((img*255.0).round()) + return np.uint8((img * 255.0).round()) # -------------------------------------------- @@ -316,6 +315,7 @@ def tensor2single(img): return img + # convert torch tensor to single def tensor2single3(img): img = img.data.squeeze().float().cpu().numpy() @@ -340,11 +340,11 @@ def single42tensor4(img): # from skimage.io import imread, imsave def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): - ''' + """ Converts a torch Tensor into an image Numpy array of BGR channel order Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) - ''' + """ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] n_dim = tensor.dim() @@ -358,15 +358,14 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): elif n_dim == 2: img_np = tensor.numpy() else: - raise TypeError( - 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + raise TypeError("Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(n_dim)) if out_type == np.uint8: img_np = (img_np * 255.0).round() # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. return img_np.astype(out_type) -''' +""" # -------------------------------------------- # Augmentation, flipe and/or rotate # -------------------------------------------- @@ -374,12 +373,11 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): # (1) augmet_img: numpy image of WxHxC or WxH # (2) augment_img_tensor4: tensor image 1xCxWxH # -------------------------------------------- -''' +""" def augment_img(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" if mode == 0: return img elif mode == 1: @@ -399,8 +397,7 @@ def augment_img(img, mode=0): def augment_img_tensor4(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" if mode == 0: return img elif mode == 1: @@ -420,8 +417,7 @@ def augment_img_tensor4(img, mode=0): def augment_img_tensor(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' + """Kai Zhang (github: https://github.com/cszn)""" img_size = img.size() img_np = img.data.cpu().numpy() if len(img_size) == 3: @@ -484,11 +480,11 @@ def _augment(img): return [_augment(img) for img in img_list] -''' +""" # -------------------------------------------- # modcrop and shave # -------------------------------------------- -''' +""" def modcrop(img_in, scale): @@ -497,13 +493,13 @@ def modcrop(img_in, scale): if img.ndim == 2: H, W = img.shape H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r] + img = img[: H - H_r, : W - W_r] elif img.ndim == 3: H, W, C = img.shape H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r, :] + img = img[: H - H_r, : W - W_r, :] else: - raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim)) return img @@ -511,11 +507,11 @@ def shave(img_in, border=0): # img_in: Numpy, HWC or HW img = np.copy(img_in) h, w = img.shape[:2] - img = img[border:h-border, border:w-border] + img = img[border : h - border, border : w - border] return img -''' +""" # -------------------------------------------- # image processing process on numpy image # channel_convert(in_c, tar_type, img_list): @@ -523,96 +519,99 @@ def shave(img_in, border=0): # bgr2ycbcr(img, only_y=True): # ycbcr2rgb(img): # -------------------------------------------- -''' +""" def rgb2ycbcr(img, only_y=True): - '''same as matlab rgb2ycbcr + """same as matlab rgb2ycbcr only_y: only return Y channel Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert if only_y: rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], - [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + rlt = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]] + ) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def ycbcr2rgb(img): - '''same as matlab ycbcr2rgb + """same as matlab ycbcr2rgb Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert - rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], - [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + rlt = np.matmul( + img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], [0.00625893, -0.00318811, 0]] + ) * 255.0 + [-222.921, 135.576, -276.836] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def bgr2ycbcr(img, only_y=True): - '''bgr version of rgb2ycbcr + """bgr version of rgb2ycbcr only_y: only return Y channel Input: uint8, [0, 255] float, [0, 1] - ''' + """ in_img_type = img.dtype img.astype(np.float32) if in_img_type != np.uint8: - img *= 255. + img *= 255.0 # convert if only_y: rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], - [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + rlt = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]] + ) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: - rlt /= 255. + rlt /= 255.0 return rlt.astype(in_img_type) def channel_convert(in_c, tar_type, img_list): # conversion among BGR, gray and y - if in_c == 3 and tar_type == 'gray': # BGR to gray + if in_c == 3 and tar_type == "gray": # BGR to gray gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] return [np.expand_dims(img, axis=2) for img in gray_list] - elif in_c == 3 and tar_type == 'y': # BGR to y + elif in_c == 3 and tar_type == "y": # BGR to y y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] return [np.expand_dims(img, axis=2) for img in y_list] - elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + elif in_c == 1 and tar_type == "RGB": # gray/y to BGR return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] else: return img_list -''' +""" # -------------------------------------------- # metric, PSNR and SSIM # -------------------------------------------- -''' +""" # -------------------------------------------- @@ -620,19 +619,19 @@ def channel_convert(in_c, tar_type, img_list): # -------------------------------------------- def calculate_psnr(img1, img2, border=0): # img1 and img2 have range [0, 255] - #img1 = img1.squeeze() - #img2 = img2.squeeze() + # img1 = img1.squeeze() + # img2 = img2.squeeze() if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') + raise ValueError("Input images must have the same dimensions.") h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) - mse = np.mean((img1 - img2)**2) + mse = np.mean((img1 - img2) ** 2) if mse == 0: - return float('inf') + return float("inf") return 20 * math.log10(255.0 / math.sqrt(mse)) @@ -640,17 +639,17 @@ def calculate_psnr(img1, img2, border=0): # SSIM # -------------------------------------------- def calculate_ssim(img1, img2, border=0): - '''calculate SSIM + """calculate SSIM the same outputs as MATLAB's img1, img2: [0, 255] - ''' - #img1 = img1.squeeze() - #img2 = img2.squeeze() + """ + # img1 = img1.squeeze() + # img2 = img2.squeeze() if not img1.shape == img2.shape: - raise ValueError('Input images must have the same dimensions.') + raise ValueError("Input images must have the same dimensions.") h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border : h - border, border : w - border] + img2 = img2[border : h - border, border : w - border] if img1.ndim == 2: return ssim(img1, img2) @@ -658,17 +657,17 @@ def calculate_ssim(img1, img2, border=0): if img1.shape[2] == 3: ssims = [] for i in range(3): - ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + ssims.append(ssim(img1[:, :, i], img2[:, :, i])) return np.array(ssims).mean() elif img1.shape[2] == 1: return ssim(np.squeeze(img1), np.squeeze(img2)) else: - raise ValueError('Wrong input image dimensions.') + raise ValueError("Wrong input image dimensions.") def ssim(img1, img2): - C1 = (0.01 * 255)**2 - C2 = (0.03 * 255)**2 + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) @@ -684,16 +683,15 @@ def ssim(img1, img2): sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * - (sigma1_sq + sigma2_sq + C2)) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) return ssim_map.mean() -''' +""" # -------------------------------------------- # matlab's bicubic imresize (numpy and torch) [0, 1] # -------------------------------------------- -''' +""" # matlab 'imresize' function, now only support 'bicubic' @@ -701,8 +699,9 @@ def cubic(x): absx = torch.abs(x) absx2 = absx**2 absx3 = absx**3 - return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ - (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + ( + -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2 + ) * (((absx > 1) * (absx <= 2)).type_as(absx)) def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): @@ -729,8 +728,9 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width # The indices of the input pixels involved in computing the k-th output # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( - 1, P).expand(out_length, P) + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand( + out_length, P + ) # The weights used to compute the k-th output pixel are in row k of the # weights matrix. @@ -773,7 +773,7 @@ def imresize(img, scale, antialiasing=True): in_C, in_H, in_W = img.size() out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) kernel_width = 4 - kernel = 'cubic' + kernel = "cubic" # Return the desired dimension order for performing the resize. The # strategy is to perform the resize first along the dimension with the @@ -782,9 +782,11 @@ def imresize(img, scale, antialiasing=True): # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) @@ -805,7 +807,7 @@ def imresize(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -827,7 +829,7 @@ def imresize(img, scale, antialiasing=True): for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): - out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_W[i]) if need_squeeze: out_2.squeeze_() return out_2 @@ -848,7 +850,7 @@ def imresize_np(img, scale, antialiasing=True): in_H, in_W, in_C = img.size() out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) kernel_width = 4 - kernel = 'cubic' + kernel = "cubic" # Return the desired dimension order for performing the resize. The # strategy is to perform the resize first along the dimension with the @@ -857,9 +859,11 @@ def imresize_np(img, scale, antialiasing=True): # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) @@ -880,7 +884,7 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -902,15 +906,15 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): - out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv(weights_W[i]) if need_squeeze: out_2.squeeze_() return out_2.numpy() -if __name__ == '__main__': - print('---') +if __name__ == "__main__": + print("---") # img = imread_uint('test.bmp', 3) # img = uint2single(img) -# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file +# img_bicubic = imresize_np(img, 1/4) diff --git a/examples/images/diffusion/ldm/modules/midas/api.py b/examples/images/diffusion/ldm/modules/midas/api.py index b58ebbffd942..6619f515fa0e 100644 --- a/examples/images/diffusion/ldm/modules/midas/api.py +++ b/examples/images/diffusion/ldm/modules/midas/api.py @@ -3,13 +3,11 @@ import cv2 import torch import torch.nn as nn -from torchvision.transforms import Compose - from ldm.modules.midas.midas.dpt_depth import DPTDepthModel from ldm.modules.midas.midas.midas_net import MidasNet from ldm.modules.midas.midas.midas_net_custom import MidasNet_small -from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet - +from ldm.modules.midas.midas.transforms import NormalizeImage, PrepareForNet, Resize +from torchvision.transforms import Compose ISL_PATHS = { "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", @@ -98,18 +96,20 @@ def load_model(model_type): model = MidasNet(model_path, non_negative=True) net_w, net_h = 384, 384 resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) elif model_type == "midas_v21_small": - model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, - non_negative=True, blocks={'expand': True}) + model = MidasNet_small( + model_path, + features=64, + backbone="efficientnet_lite3", + exportable=True, + non_negative=True, + blocks={"expand": True}, + ) net_w, net_h = 256, 256 resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) else: print(f"model_type '{model_type}' not implemented, use: --model_type large") @@ -135,11 +135,7 @@ def load_model(model_type): class MiDaSInference(nn.Module): - MODEL_TYPES_TORCH_HUB = [ - "DPT_Large", - "DPT_Hybrid", - "MiDaS_small" - ] + MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"] MODEL_TYPES_ISL = [ "dpt_large", "dpt_hybrid", @@ -149,7 +145,7 @@ class MiDaSInference(nn.Module): def __init__(self, model_type): super().__init__() - assert (model_type in self.MODEL_TYPES_ISL) + assert model_type in self.MODEL_TYPES_ISL model, _ = load_model(model_type) self.model = model self.model.train = disabled_train @@ -167,4 +163,3 @@ def forward(self, x): ) assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) return prediction - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py index 5cf430239b47..5c2e0e93b049 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py @@ -8,7 +8,7 @@ def load(self, path): Args: path (str): file path """ - parameters = torch.load(path, map_location=torch.device('cpu')) + parameters = torch.load(path, map_location=torch.device("cpu")) if "optimizer" in parameters: parameters = parameters["model"] diff --git a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py index 2145d18fa980..154de57cd2e8 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py @@ -1,18 +1,22 @@ import torch import torch.nn as nn -from .vit import ( - _make_pretrained_vitb_rn50_384, - _make_pretrained_vitl16_384, - _make_pretrained_vitb16_384, - forward_vit, -) - -def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): +from .vit import _make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384, _make_pretrained_vitl16_384 + + +def _make_encoder( + backbone, + features, + use_pretrained, + groups=1, + expand=False, + exportable=True, + hooks=None, + use_vit_only=False, + use_readout="ignore", +): if backbone == "vitl16_384": - pretrained = _make_pretrained_vitl16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) + pretrained = _make_pretrained_vitl16_384(use_pretrained, hooks=hooks, use_readout=use_readout) scratch = _make_scratch( [256, 512, 1024, 1024], features, groups=groups, expand=expand ) # ViT-L/16 - 85.0% Top1 (backbone) @@ -27,22 +31,20 @@ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, ex [256, 512, 768, 768], features, groups=groups, expand=expand ) # ViT-H/16 - 85.0% Top1 (backbone) elif backbone == "vitb16_384": - pretrained = _make_pretrained_vitb16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) + pretrained = _make_pretrained_vitb16_384(use_pretrained, hooks=hooks, use_readout=use_readout) scratch = _make_scratch( [96, 192, 384, 768], features, groups=groups, expand=expand ) # ViT-B/16 - 84.6% Top1 (backbone) elif backbone == "resnext101_wsl": pretrained = _make_pretrained_resnext101_wsl(use_pretrained) - scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 elif backbone == "efficientnet_lite3": pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) - scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 else: print(f"Backbone '{backbone}' not implemented") assert False - + return pretrained, scratch @@ -53,11 +55,11 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False): out_shape2 = out_shape out_shape3 = out_shape out_shape4 = out_shape - if expand==True: + if expand == True: out_shape1 = out_shape - out_shape2 = out_shape*2 - out_shape3 = out_shape*4 - out_shape4 = out_shape*8 + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 scratch.layer1_rn = nn.Conv2d( in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups @@ -77,10 +79,7 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False): def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): efficientnet = torch.hub.load( - "rwightman/gen-efficientnet-pytorch", - "tf_efficientnet_lite3", - pretrained=use_pretrained, - exportable=exportable + "rwightman/gen-efficientnet-pytorch", "tf_efficientnet_lite3", pretrained=use_pretrained, exportable=exportable ) return _make_efficientnet_backbone(efficientnet) @@ -88,21 +87,17 @@ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): def _make_efficientnet_backbone(effnet): pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] - ) + pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]) pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) return pretrained - + def _make_resnet_backbone(resnet): pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 - ) + pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1) pretrained.layer2 = resnet.layer2 pretrained.layer3 = resnet.layer3 @@ -116,10 +111,8 @@ def _make_pretrained_resnext101_wsl(use_pretrained): return _make_resnet_backbone(resnet) - class Interpolate(nn.Module): - """Interpolation module. - """ + """Interpolation module.""" def __init__(self, scale_factor, mode, align_corners=False): """Init. @@ -145,16 +138,13 @@ def forward(self, x): tensor: interpolated data """ - x = self.interp( - x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners - ) + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) return x class ResidualConvUnit(nn.Module): - """Residual convolution module. - """ + """Residual convolution module.""" def __init__(self, features): """Init. @@ -164,13 +154,9 @@ def __init__(self, features): """ super().__init__() - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) self.relu = nn.ReLU(inplace=True) @@ -192,8 +178,7 @@ def forward(self, x): class FeatureFusionBlock(nn.Module): - """Feature fusion block. - """ + """Feature fusion block.""" def __init__(self, features): """Init. @@ -219,18 +204,13 @@ def forward(self, *xs): output = self.resConfUnit2(output) - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=True - ) + output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=True) return output - - class ResidualConvUnit_custom(nn.Module): - """Residual convolution module. - """ + """Residual convolution module.""" def __init__(self, features, activation, bn): """Init. @@ -242,17 +222,13 @@ def __init__(self, features, activation, bn): self.bn = bn - self.groups=1 + self.groups = 1 - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) - - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) - if self.bn==True: + if self.bn == True: self.bn1 = nn.BatchNorm2d(features) self.bn2 = nn.BatchNorm2d(features) @@ -269,15 +245,15 @@ def forward(self, x): Returns: tensor: output """ - + out = self.activation(x) out = self.conv1(out) - if self.bn==True: + if self.bn == True: out = self.bn1(out) - + out = self.activation(out) out = self.conv2(out) - if self.bn==True: + if self.bn == True: out = self.bn2(out) if self.groups > 1: @@ -289,8 +265,7 @@ def forward(self, x): class FeatureFusionBlock_custom(nn.Module): - """Feature fusion block. - """ + """Feature fusion block.""" def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): """Init. @@ -303,18 +278,18 @@ def __init__(self, features, activation, deconv=False, bn=False, expand=False, a self.deconv = deconv self.align_corners = align_corners - self.groups=1 + self.groups = 1 self.expand = expand out_features = features - if self.expand==True: - out_features = features//2 - + if self.expand == True: + out_features = features // 2 + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) - + self.skip_add = nn.quantized.FloatFunctional() def forward(self, *xs): @@ -332,11 +307,8 @@ def forward(self, *xs): output = self.resConfUnit2(output) - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=self.align_corners - ) + output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=self.align_corners) output = self.out_conv(output) return output - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py index 4e9aab5d2767..74871e8b1fce 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py @@ -1,15 +1,8 @@ import torch import torch.nn as nn -import torch.nn.functional as F from .base_model import BaseModel -from .blocks import ( - FeatureFusionBlock, - FeatureFusionBlock_custom, - Interpolate, - _make_encoder, - forward_vit, -) +from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder, forward_vit def _make_fusion_block(features, use_bn): @@ -33,7 +26,6 @@ def __init__( channels_last=False, use_bn=False, ): - super(DPT, self).__init__() self.channels_last = channels_last @@ -48,7 +40,7 @@ def __init__( self.pretrained, self.scratch = _make_encoder( backbone, features, - False, # Set to true of you want to train from scratch, uses ImageNet weights + False, # Set to true of you want to train from scratch, uses ImageNet weights groups=1, expand=False, exportable=False, @@ -63,7 +55,6 @@ def __init__( self.scratch.output_conv = head - def forward(self, x): if self.channels_last == True: x.contiguous(memory_format=torch.channels_last) @@ -102,8 +93,7 @@ def __init__(self, path=None, non_negative=True, **kwargs): super().__init__(head, **kwargs) if path is not None: - self.load(path) + self.load(path) def forward(self, x): return super().forward(x).squeeze(dim=1) - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py index 8a954977800b..0dd87b59619c 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py @@ -10,8 +10,7 @@ class MidasNet(BaseModel): - """Network for monocular depth estimation. - """ + """Network for monocular depth estimation.""" def __init__(self, path=None, features=256, non_negative=True): """Init. @@ -27,7 +26,9 @@ def __init__(self, path=None, features=256, non_negative=True): use_pretrained = False if path is None else True - self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + self.pretrained, self.scratch = _make_encoder( + backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained + ) self.scratch.refinenet4 = FeatureFusionBlock(features) self.scratch.refinenet3 = FeatureFusionBlock(features) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py index 50e4acb5e53d..4d30744c46d3 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py @@ -6,15 +6,23 @@ import torch.nn as nn from .base_model import BaseModel -from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder +from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder class MidasNet_small(BaseModel): - """Network for monocular depth estimation. - """ - - def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, - blocks={'expand': True}): + """Network for monocular depth estimation.""" + + def __init__( + self, + path=None, + features=64, + backbone="efficientnet_lite3", + non_negative=True, + exportable=True, + channels_last=False, + align_corners=True, + blocks={"expand": True}, + ): """Init. Args: @@ -27,49 +35,57 @@ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_ne super(MidasNet_small, self).__init__() use_pretrained = False if path else True - + self.channels_last = channels_last self.blocks = blocks self.backbone = backbone self.groups = 1 - features1=features - features2=features - features3=features - features4=features + features1 = features + features2 = features + features3 = features + features4 = features self.expand = False - if "expand" in self.blocks and self.blocks['expand'] == True: + if "expand" in self.blocks and self.blocks["expand"] == True: self.expand = True - features1=features - features2=features*2 - features3=features*4 - features4=features*8 + features1 = features + features2 = features * 2 + features3 = features * 4 + features4 = features * 8 + + self.pretrained, self.scratch = _make_encoder( + self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable + ) - self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) - - self.scratch.activation = nn.ReLU(False) + self.scratch.activation = nn.ReLU(False) - self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + self.scratch.refinenet4 = FeatureFusionBlock_custom( + features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet3 = FeatureFusionBlock_custom( + features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet2 = FeatureFusionBlock_custom( + features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners + ) + self.scratch.refinenet1 = FeatureFusionBlock_custom( + features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners + ) - self.scratch.output_conv = nn.Sequential( - nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups), Interpolate(scale_factor=2, mode="bilinear"), - nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), self.scratch.activation, nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), nn.ReLU(True) if non_negative else nn.Identity(), nn.Identity(), ) - + if path: self.load(path) - def forward(self, x): """Forward pass. @@ -79,38 +95,35 @@ def forward(self, x): Returns: tensor: depth """ - if self.channels_last==True: + if self.channels_last == True: print("self.channels_last = ", self.channels_last) x.contiguous(memory_format=torch.channels_last) - layer_1 = self.pretrained.layer1(x) layer_2 = self.pretrained.layer2(layer_1) layer_3 = self.pretrained.layer3(layer_2) layer_4 = self.pretrained.layer4(layer_3) - + layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) - path_4 = self.scratch.refinenet4(layer_4_rn) path_3 = self.scratch.refinenet3(path_4, layer_3_rn) path_2 = self.scratch.refinenet2(path_3, layer_2_rn) path_1 = self.scratch.refinenet1(path_2, layer_1_rn) - + out = self.scratch.output_conv(path_1) return torch.squeeze(out, dim=1) - def fuse_model(m): prev_previous_type = nn.Identity() - prev_previous_name = '' + prev_previous_name = "" previous_type = nn.Identity() - previous_name = '' + previous_name = "" for name, module in m.named_modules(): if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: # print("FUSED ", prev_previous_name, previous_name, name) @@ -125,4 +138,4 @@ def fuse_model(m): prev_previous_type = previous_type prev_previous_name = previous_name previous_type = type(module) - previous_name = name \ No newline at end of file + previous_name = name diff --git a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py index 350cbc116626..aede0fa0c73f 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py @@ -1,7 +1,8 @@ -import numpy as np -import cv2 import math +import cv2 +import numpy as np + def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): """Rezise the sample to ensure the given size. Keeps aspect ratio. @@ -28,13 +29,9 @@ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): shape[1] = math.ceil(scale * shape[1]) # resize - sample["image"] = cv2.resize( - sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method - ) + sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method) - sample["disparity"] = cv2.resize( - sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST - ) + sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), tuple(shape[::-1]), @@ -46,8 +43,7 @@ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): class Resize(object): - """Resize sample to given size (width, height). - """ + """Resize sample to given size (width, height).""" def __init__( self, @@ -133,24 +129,14 @@ def get_size(self, width, height): # fit height scale_width = scale_height else: - raise ValueError( - f"resize_method {self.__resize_method} not implemented" - ) + raise ValueError(f"resize_method {self.__resize_method} not implemented") if self.__resize_method == "lower_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, min_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, min_val=self.__width - ) + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) elif self.__resize_method == "upper_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, max_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, max_val=self.__width - ) + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) elif self.__resize_method == "minimal": new_height = self.constrain_to_multiple_of(scale_height * height) new_width = self.constrain_to_multiple_of(scale_width * width) @@ -160,9 +146,7 @@ def get_size(self, width, height): return (new_width, new_height) def __call__(self, sample): - width, height = self.get_size( - sample["image"].shape[1], sample["image"].shape[0] - ) + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) # resize sample sample["image"] = cv2.resize( @@ -180,9 +164,7 @@ def __call__(self, sample): ) if "depth" in sample: - sample["depth"] = cv2.resize( - sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST - ) + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), @@ -195,8 +177,7 @@ def __call__(self, sample): class NormalizeImage(object): - """Normlize image by given mean and std. - """ + """Normlize image by given mean and std.""" def __init__(self, mean, std): self.__mean = mean @@ -209,8 +190,7 @@ def __call__(self, sample): class PrepareForNet(object): - """Prepare sample for usage as network input. - """ + """Prepare sample for usage as network input.""" def __init__(self): pass diff --git a/examples/images/diffusion/ldm/modules/midas/midas/vit.py b/examples/images/diffusion/ldm/modules/midas/midas/vit.py index ea46b1be88b2..41bdb566fd4f 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/vit.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/vit.py @@ -1,8 +1,9 @@ +import math +import types + +import timm import torch import torch.nn as nn -import timm -import types -import math import torch.nn.functional as F @@ -56,7 +57,7 @@ def forward(self, x): def forward_vit(pretrained, x): b, c, h, w = x.shape - glob = pretrained.model.forward_flex(x) + pretrained.model.forward_flex(x) layer_1 = pretrained.activations["1"] layer_2 = pretrained.activations["2"] @@ -117,9 +118,7 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w): def forward_flex(self, x): b, c, h, w = x.shape - pos_embed = self._resize_pos_embed( - self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] - ) + pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]) B = x.shape[0] @@ -131,15 +130,11 @@ def forward_flex(self, x): x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) if getattr(self, "dist_token", None) is not None: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks dist_token = self.dist_token.expand(B, -1, -1) x = torch.cat((cls_tokens, dist_token, x), dim=1) else: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) x = x + pos_embed @@ -169,13 +164,9 @@ def get_readout_oper(vit_features, features, use_readout, start_index=1): elif use_readout == "add": readout_oper = [AddReadout(start_index)] * len(features) elif use_readout == "project": - readout_oper = [ - ProjectReadout(vit_features, start_index) for out_feat in features - ] + readout_oper = [ProjectReadout(vit_features, start_index) for out_feat in features] else: - assert ( - False - ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + assert False, "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" return readout_oper @@ -287,9 +278,7 @@ def _make_vit_b16_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) - pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) + pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model) return pretrained @@ -311,24 +300,18 @@ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) + return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout) def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) + return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout) def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): - model = timm.create_model( - "vit_deit_base_distilled_patch16_384", pretrained=pretrained - ) + model = timm.create_model("vit_deit_base_distilled_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks return _make_vit_b16_backbone( @@ -358,12 +341,8 @@ def _make_vit_b_rn50_backbone( pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) else: - pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( - get_activation("1") - ) - pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( - get_activation("2") - ) + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(get_activation("1")) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(get_activation("2")) pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) @@ -419,12 +398,8 @@ def _make_vit_b_rn50_backbone( ), ) else: - pretrained.act_postprocess1 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) - pretrained.act_postprocess2 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) + pretrained.act_postprocess1 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + pretrained.act_postprocess2 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) pretrained.act_postprocess3 = nn.Sequential( readout_oper[2], @@ -468,16 +443,12 @@ def _make_vit_b_rn50_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. - pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) + pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model) return pretrained -def _make_pretrained_vitb_rn50_384( - pretrained, use_readout="ignore", hooks=None, use_vit_only=False -): +def _make_pretrained_vitb_rn50_384(pretrained, use_readout="ignore", hooks=None, use_vit_only=False): model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) hooks = [0, 1, 8, 11] if hooks == None else hooks diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py index 9a9d3b5b6637..1428d42b2445 100644 --- a/examples/images/diffusion/ldm/modules/midas/utils.py +++ b/examples/images/diffusion/ldm/modules/midas/utils.py @@ -1,8 +1,9 @@ """Utils for monoDepth.""" -import sys import re -import numpy as np +import sys + import cv2 +import numpy as np import torch @@ -16,7 +17,6 @@ def read_pfm(path): tuple: (data, scale) """ with open(path, "rb") as file: - color = None width = None height = None @@ -74,9 +74,7 @@ def write_pfm(path, image, scale=1): if len(image.shape) == 3 and image.shape[2] == 3: # color image color = True - elif ( - len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 - ): # greyscale + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale color = False else: raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") @@ -135,9 +133,7 @@ def resize_image(img): img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) - img_resized = ( - torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() - ) + img_resized = torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() img_resized = img_resized.unsqueeze(0) return img_resized @@ -156,12 +152,11 @@ def resize_depth(depth, width, height): """ depth = torch.squeeze(depth[0, :, :, :]).to("cpu") - depth_resized = cv2.resize( - depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC - ) + depth_resized = cv2.resize(depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC) return depth_resized + def write_depth(path, depth, bits=1): """Write depth map to pfm and png file. @@ -174,7 +169,7 @@ def write_depth(path, depth, bits=1): depth_min = depth.min() depth_max = depth.max() - max_val = (2**(8*bits))-1 + max_val = (2 ** (8 * bits)) - 1 if depth_max - depth_min > np.finfo("float").eps: out = max_val * (depth - depth_min) / (depth_max - depth_min) diff --git a/examples/images/diffusion/ldm/util.py b/examples/images/diffusion/ldm/util.py index 8c09ca1c72f7..9b52b199aa2c 100644 --- a/examples/images/diffusion/ldm/util.py +++ b/examples/images/diffusion/ldm/util.py @@ -1,11 +1,10 @@ import importlib +from inspect import isfunction -import torch -from torch import optim import numpy as np - -from inspect import isfunction +import torch from PIL import Image, ImageDraw, ImageFont +from torch import optim def log_txt_as_img(wh, xc, size=10): @@ -16,9 +15,9 @@ def log_txt_as_img(wh, xc, size=10): for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) - font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) nc = int(40 * (wh[0] / 256)) - lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)) try: draw.text((0, 0), lines, fill="black", font=font) @@ -39,7 +38,7 @@ def ismap(x): def isimage(x): - if not isinstance(x,torch.Tensor): + if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) @@ -71,7 +70,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): if not "target" in config: - if config == '__is_first_stage__': + if config == "__is_first_stage__": return None elif config == "__is_unconditional__": return None @@ -89,9 +88,18 @@ def get_obj_from_str(string, reload=False): class AdamWwithEMAandWings(optim.Optimizer): # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 - def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using - weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code - ema_power=1., param_names=()): + def __init__( + self, + params, + lr=1.0e-3, + betas=(0.9, 0.999), + eps=1.0e-8, # TODO: check hyperparameters before using + weight_decay=1.0e-2, + amsgrad=False, + ema_decay=0.9999, # ema decay to match previous code + ema_power=1.0, + param_names=(), + ): """AdamW that saves EMA versions of the parameters.""" if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -105,15 +113,22 @@ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: che raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0.0 <= ema_decay <= 1.0: raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, - ema_power=ema_power, param_names=param_names) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ema_decay=ema_decay, + ema_power=ema_power, + param_names=param_names, + ) super().__init__(params, defaults) def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: - group.setdefault('amsgrad', False) + group.setdefault("amsgrad", False) @torch.no_grad() def step(self, closure=None): @@ -133,65 +148,66 @@ def step(self, closure=None): exp_avgs = [] exp_avg_sqs = [] ema_params_with_grad = [] - state_sums = [] max_exp_avg_sqs = [] state_steps = [] - amsgrad = group['amsgrad'] - beta1, beta2 = group['betas'] - ema_decay = group['ema_decay'] - ema_power = group['ema_power'] + amsgrad = group["amsgrad"] + beta1, beta2 = group["betas"] + ema_decay = group["ema_decay"] + ema_power = group["ema_power"] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue params_with_grad.append(p) if p.grad.is_sparse: - raise RuntimeError('AdamW does not support sparse gradients') + raise RuntimeError("AdamW does not support sparse gradients") grads.append(p.grad) state = self.state[p] # State initialization if len(state) == 0: - state['step'] = 0 + state["step"] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) # Exponential moving average of parameter values - state['param_exp_avg'] = p.detach().float().clone() + state["param_exp_avg"] = p.detach().float().clone() - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - ema_params_with_grad.append(state['param_exp_avg']) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + ema_params_with_grad.append(state["param_exp_avg"]) if amsgrad: - max_exp_avg_sqs.append(state['max_exp_avg_sq']) + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) # update the steps for each param group update - state['step'] += 1 + state["step"] += 1 # record the step after step update - state_steps.append(state['step']) - - optim._functional.adamw(params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps'], - maximize=False) - - cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + state_steps.append(state["step"]) + + optim._functional.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=False, + ) + + cur_ema_decay = min(ema_decay, 1 - state["step"] ** -ema_power) for param, ema_param in zip(params_with_grad, ema_params_with_grad): ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) - return loss \ No newline at end of file + return loss diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py index 713029fc677d..6d44df667fce 100644 --- a/examples/images/diffusion/main.py +++ b/examples/images/diffusion/main.py @@ -1,33 +1,28 @@ import argparse -import csv import datetime import glob -import importlib import os import sys import time +from functools import partial +import lightning.pytorch as pl import numpy as np import torch import torchvision -import lightning.pytorch as pl - - -from functools import partial - -from omegaconf import OmegaConf -from packaging import version -from PIL import Image -from prefetch_generator import BackgroundGenerator -from torch.utils.data import DataLoader, Dataset, Subset, random_split from ldm.models.diffusion.ddpm import LatentDiffusion - from lightning.pytorch import seed_everything from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger +from lightning.pytorch.strategies import ColossalAIStrategy, DDPStrategy from lightning.pytorch.trainer import Trainer from lightning.pytorch.utilities import rank_zero_info, rank_zero_only -from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger -from lightning.pytorch.strategies import ColossalAIStrategy,DDPStrategy +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from prefetch_generator import BackgroundGenerator +from torch.utils.data import DataLoader, Dataset + LIGHTNING_PACK_NAME = "lightning.pytorch." from ldm.data.base import Txt2ImgIterableBaseDataset @@ -37,15 +32,15 @@ class DataLoaderX(DataLoader): -# A custom data loader class that inherits from DataLoader + # A custom data loader class that inherits from DataLoader def __iter__(self): # Overriding the __iter__ method of DataLoader to return a BackgroundGenerator - #This is to enable data loading in the background to improve training performance + # This is to enable data loading in the background to improve training performance return BackgroundGenerator(super().__iter__()) def get_parser(**parser_kwargs): - #A function to create an ArgumentParser object and add arguments to it + # A function to create an ArgumentParser object and add arguments to it def str2bool(v): # A helper function to parse boolean values from command line arguments @@ -57,6 +52,7 @@ def str2bool(v): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") + # Create an ArgumentParser object with specifies kwargs parser = argparse.ArgumentParser(**parser_kwargs) @@ -160,6 +156,7 @@ def str2bool(v): return parser + # A function that returns the non-default arguments between two objects def nondefault_trainer_args(opt): # create an argument parser @@ -171,6 +168,7 @@ def nondefault_trainer_args(opt): # return all non-default arguments return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) + # A dataset wrapper class to create a pytorch dataset from an arbitrary object class WrappedDataset(Dataset): """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" @@ -184,6 +182,7 @@ def __len__(self): def __getitem__(self, idx): return self.data[idx] + # A function to initialize worker processes def worker_init_fn(_): worker_info = torch.utils.data.get_worker_info() @@ -192,31 +191,33 @@ def worker_init_fn(_): worker_id = worker_info.id if isinstance(dataset, Txt2ImgIterableBaseDataset): - #divide the dataset into equal parts for each worker + # divide the dataset into equal parts for each worker split_size = dataset.num_records // worker_info.num_workers - #set the sample IDs for the current worker + # set the sample IDs for the current worker # reset num_records to the true number to retain reliable length information - dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size] # set the seed for the current worker current_id = np.random.choice(len(np.random.get_state()[1]), 1) return np.random.seed(np.random.get_state()[1][current_id] + worker_id) else: return np.random.seed(np.random.get_state()[1][0] + worker_id) -#Provide functionality for creating data loaders based on provided dataset configurations -class DataModuleFromConfig(pl.LightningDataModule): - def __init__(self, - batch_size, - train=None, - validation=None, - test=None, - predict=None, - wrap=False, - num_workers=None, - shuffle_test_loader=False, - use_worker_init_fn=False, - shuffle_val_dataloader=False): +# Provide functionality for creating data loaders based on provided dataset configurations +class DataModuleFromConfig(pl.LightningDataModule): + def __init__( + self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, + shuffle_val_dataloader=False, + ): super().__init__() # Set data module attributes self.batch_size = batch_size @@ -246,43 +247,47 @@ def prepare_data(self): def setup(self, stage=None): # Instantiate datasets from the dataset configs self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) - + # If wrap is true, create a WrappedDataset for each dataset if self.wrap: for k in self.datasets: self.datasets[k] = WrappedDataset(self.datasets[k]) def _train_dataloader(self): - #Check if the train dataset is iterable - is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) - #Set the worker initialization function of the dataset is iterable or use_worker_init_fn is True + # Check if the train dataset is iterable + is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset) + # Set the worker initialization function of the dataset is iterable or use_worker_init_fn is True if is_iterable_dataset or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None # Return a DataLoaderX object for the train dataset - return DataLoaderX(self.datasets["train"], - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False if is_iterable_dataset else True, - worker_init_fn=init_fn) + return DataLoaderX( + self.datasets["train"], + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False if is_iterable_dataset else True, + worker_init_fn=init_fn, + ) def _val_dataloader(self, shuffle=False): - #Check if the validation dataset is iterable - if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + # Check if the validation dataset is iterable + if isinstance(self.datasets["validation"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None # Return a DataLoaderX object for the validation dataset - return DataLoaderX(self.datasets["validation"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn, - shuffle=shuffle) + return DataLoaderX( + self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) def _test_dataloader(self, shuffle=False): # Check if the test dataset is iterable - is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) + is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset) # Set the worker initialization function if the dataset is iterable or use_worker_init_fn is True if is_iterable_dataset or self.use_worker_init_fn: init_fn = worker_init_fn @@ -292,21 +297,22 @@ def _test_dataloader(self, shuffle=False): # do not shuffle dataloader for iterable dataset shuffle = shuffle and (not is_iterable_dataset) - return DataLoaderX(self.datasets["test"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn, - shuffle=shuffle) + return DataLoaderX( + self.datasets["test"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) def _predict_dataloader(self, shuffle=False): - if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: + if isinstance(self.datasets["predict"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None - return DataLoaderX(self.datasets["predict"], - batch_size=self.batch_size, - num_workers=self.num_workers, - worker_init_fn=init_fn) + return DataLoaderX( + self.datasets["predict"], batch_size=self.batch_size, num_workers=self.num_workers, worker_init_fn=init_fn + ) class SetupCallback(Callback): @@ -338,10 +344,10 @@ def on_fit_start(self, trainer, pl_module): os.makedirs(self.ckptdir, exist_ok=True) os.makedirs(self.cfgdir, exist_ok=True) - #Create trainstep checkpoint directory if necessary + # Create trainstep checkpoint directory if necessary if "callbacks" in self.lightning_config: - if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: - os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) + if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]: + os.makedirs(os.path.join(self.ckptdir, "trainstep_checkpoints"), exist_ok=True) print("Project config") print(OmegaConf.to_yaml(self.config)) OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) @@ -349,8 +355,10 @@ def on_fit_start(self, trainer, pl_module): # Save project config and lightning config as YAML files print("Lightning config") print(OmegaConf.to_yaml(self.lightning_config)) - OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), - os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) + OmegaConf.save( + OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)), + ) # Remove log directory if resuming training and directory already exists else: @@ -373,24 +381,25 @@ def on_fit_start(self, trainer, pl_module): # PyTorch Lightning callback for logging images during training and validation of a deep learning model class ImageLogger(Callback): - - def __init__(self, - batch_frequency, # Frequency of batches on which to log images - max_images, # Maximum number of images to log - clamp=True, # Whether to clamp pixel values to [-1,1] - increase_log_steps=True, # Whether to increase frequency of log steps exponentially - rescale=True, # Whether to rescale pixel values to [0,1] - disabled=False, # Whether to disable logging - log_on_batch_idx=False, # Whether to log on batch index instead of global step - log_first_step=False, # Whether to log on the first step - log_images_kwargs=None): # Additional keyword arguments to pass to log_images method + def __init__( + self, + batch_frequency, # Frequency of batches on which to log images + max_images, # Maximum number of images to log + clamp=True, # Whether to clamp pixel values to [-1,1] + increase_log_steps=True, # Whether to increase frequency of log steps exponentially + rescale=True, # Whether to rescale pixel values to [0,1] + disabled=False, # Whether to disable logging + log_on_batch_idx=False, # Whether to log on batch index instead of global step + log_first_step=False, # Whether to log on the first step + log_images_kwargs=None, + ): # Additional keyword arguments to pass to log_images method super().__init__() self.rescale = rescale self.batch_freq = batch_frequency self.max_images = max_images self.logger_log_images = { # Dictionary of logger classes and their corresponding logging methods - pl.loggers.CSVLogger: self._testtube, + pl.loggers.CSVLogger: self._testtube, } # Create a list of exponentially increasing log steps, starting from 1 and ending at batch_frequency self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] @@ -402,37 +411,39 @@ def __init__(self, self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} self.log_first_step = log_first_step - @rank_zero_only # Ensure that only the first process in distributed training executes this method - def _testtube(self, # The PyTorch Lightning module - pl_module, # A dictionary of images to log. - images, # - batch_idx, # The batch index. - split # The split (train/val) on which to log the images - ): - # Method for logging images using test-tube logger + @rank_zero_only # Ensure that only the first process in distributed training executes this method + def _testtube( + self, # The PyTorch Lightning module + pl_module, # A dictionary of images to log. + images, # + batch_idx, # The batch index. + split, # The split (train/val) on which to log the images + ): + # Method for logging images using test-tube logger for k in images: grid = torchvision.utils.make_grid(images[k]) - grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w tag = f"{split}/{k}" # Add image grid to logger's experiment pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step) @rank_zero_only - def log_local(self, - save_dir, - split, # The split (train/val) on which to log the images - images, # A dictionary of images to log - global_step, # The global step - current_epoch, # The current epoch. - batch_idx - ): - # Method for saving image grids to local file system + def log_local( + self, + save_dir, + split, # The split (train/val) on which to log the images + images, # A dictionary of images to log + global_step, # The global step + current_epoch, # The current epoch. + batch_idx, + ): + # Method for saving image grids to local file system root = os.path.join(save_dir, "images", split) for k in images: grid = torchvision.utils.make_grid(images[k], nrow=4) if self.rescale: - grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) grid = grid.numpy() grid = (grid * 255).astype(np.uint8) @@ -443,11 +454,15 @@ def log_local(self, Image.fromarray(grid).save(path) def log_img(self, pl_module, batch, batch_idx, split="train"): - #Function for logging images to both the logger and local file system. + # Function for logging images to both the logger and local file system. check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step # check if it's time to log an image batch - if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 - hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0): + if ( + self.check_frequency(check_idx) + and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0 + and callable(pl_module.log_images) + and self.max_images > 0 + ): # Get logger type and check if training mode is on logger = type(pl_module.logger) @@ -466,11 +481,12 @@ def log_img(self, pl_module, batch, batch_idx, split="train"): if isinstance(images[k], torch.Tensor): images[k] = images[k].detach().cpu() if self.clamp: - images[k] = torch.clamp(images[k], -1., 1.) + images[k] = torch.clamp(images[k], -1.0, 1.0) # Log images locally to file system - self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, - batch_idx) + self.log_local( + pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx + ) # log the images using the logger logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) @@ -482,13 +498,13 @@ def log_img(self, pl_module, batch, batch_idx, split="train"): # The function checks if it's time to log an image batch def check_frequency(self, check_idx): - if ((check_idx % self.batch_freq) == 0 or - (check_idx in self.log_steps)) and (check_idx > 0 or self.log_first_step): + if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( + check_idx > 0 or self.log_first_step + ): try: self.log_steps.pop(0) except IndexError as e: print(e) - pass return True return False @@ -503,7 +519,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) if not self.disabled and pl_module.global_step > 0: self.log_img(pl_module, batch, batch_idx, split="val") # log gradients during calibration if necessary - if hasattr(pl_module, 'calibrate_grad_norm'): + if hasattr(pl_module, "calibrate_grad_norm"): if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: self.log_gradients(trainer, pl_module, batch_idx=batch_idx) @@ -514,7 +530,7 @@ class CUDACallback(Callback): def on_train_start(self, trainer, pl_module): rank_zero_info("Training is starting") - #the method is called at the end of each training epoch + # the method is called at the end of each training epoch def on_train_end(self, trainer, pl_module): rank_zero_info("Training is ending") @@ -595,9 +611,11 @@ def on_train_epoch_end(self, trainer, pl_module): opt, unknown = parser.parse_known_args() # Verify the arguments are both specified if opt.name and opt.resume: - raise ValueError("-n/--name and -r/--resume cannot be specified both." - "If you want to resume training in a new log folder, " - "use -n/--name in combination with --resume_from_checkpoint") + raise ValueError( + "-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint" + ) # Check if the "resume" option is specified, resume training from the checkpoint if it is true ckpt = None @@ -646,7 +664,7 @@ def on_train_epoch_end(self, trainer, pl_module): # Sets the seed for the random number generator to ensure reproducibility seed_everything(opt.seed) - # Initialize and save configuration using teh OmegaConf library. + # Initialize and save configuration using teh OmegaConf library. try: # init and save configs configs = [OmegaConf.load(cfg) for cfg in opt.base] @@ -676,7 +694,7 @@ def on_train_epoch_end(self, trainer, pl_module): config.model["params"].update({"use_fp16": False}) if ckpt is not None: - #If a checkpoint path is specified in the ckpt variable, the code updates the "ckpt" key in the "params" dictionary of the config.model configuration with the value of ckpt + # If a checkpoint path is specified in the ckpt variable, the code updates the "ckpt" key in the "params" dictionary of the config.model configuration with the value of ckpt config.model["params"].update({"ckpt": ckpt}) rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"])) @@ -688,17 +706,12 @@ def on_train_epoch_end(self, trainer, pl_module): # Default logger configs to log training metrics during the training process. default_logger_cfgs = { "wandb": { - "name": nowname, - "save_dir": logdir, - "offline": opt.debug, - "id": nowname, - } - , - "tensorboard": { - "save_dir": logdir, - "name": "diff_tb", - "log_graph": True - } + "name": nowname, + "save_dir": logdir, + "offline": opt.debug, + "id": nowname, + }, + "tensorboard": {"save_dir": logdir, "name": "diff_tb", "log_graph": True}, } # Set up the logger for TensorBoard @@ -722,11 +735,11 @@ def on_train_epoch_end(self, trainer, pl_module): # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { - "dirpath": ckptdir, - "filename": "{epoch:06}", - "verbose": True, - "save_last": True, - } + "dirpath": ckptdir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + } if hasattr(model, "monitor"): default_modelckpt_cfg["monitor"] = model.monitor default_modelckpt_cfg["save_top_k"] = 3 @@ -736,48 +749,47 @@ def on_train_epoch_end(self, trainer, pl_module): else: modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) - if version.parse(pl.__version__) < version.parse('1.4.0'): + if version.parse(pl.__version__) < version.parse("1.4.0"): trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg) - #Create an empty OmegaConf configuration object + # Create an empty OmegaConf configuration object callbacks_cfg = OmegaConf.create() - - #Instantiate items according to the configs + + # Instantiate items according to the configs trainer_kwargs.setdefault("callbacks", []) setup_callback_config = { - "resume": opt.resume, # resume training if applicable - "now": now, - "logdir": logdir, # directory to save the log file - "ckptdir": ckptdir, # directory to save the checkpoint file - "cfgdir": cfgdir, # directory to save the configuration file - "config": config, # configuration dictionary - "lightning_config": lightning_config, # LightningModule configuration - } + "resume": opt.resume, # resume training if applicable + "now": now, + "logdir": logdir, # directory to save the log file + "ckptdir": ckptdir, # directory to save the checkpoint file + "cfgdir": cfgdir, # directory to save the configuration file + "config": config, # configuration dictionary + "lightning_config": lightning_config, # LightningModule configuration + } trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config)) - + image_logger_config = { - - "batch_frequency": 750, # how frequently to log images - "max_images": 4, # maximum number of images to log - "clamp": True # whether to clamp pixel values to [0,1] - } + "batch_frequency": 750, # how frequently to log images + "max_images": 4, # maximum number of images to log + "clamp": True, # whether to clamp pixel values to [0,1] + } trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config)) - + learning_rate_logger_config = { - "logging_interval": "step", # logging frequency (either 'step' or 'epoch') - # "log_momentum": True # whether to log momentum (currently commented out) - } + "logging_interval": "step", # logging frequency (either 'step' or 'epoch') + # "log_momentum": True # whether to log momentum (currently commented out) + } trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config)) - - metrics_over_trainsteps_checkpoint_config= { - "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), + + metrics_over_trainsteps_checkpoint_config = { + "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"), "filename": "{epoch:06}-{step:09}", "verbose": True, - 'save_top_k': -1, - 'every_n_train_steps': 10000, - 'save_weights_only': True - } + "save_top_k": -1, + "every_n_train_steps": 10000, + "save_weights_only": True, + } trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_trainsteps_checkpoint_config)) trainer_kwargs["callbacks"].append(CUDACallback()) @@ -805,7 +817,7 @@ def on_train_epoch_end(self, trainer, pl_module): ngpu = trainer_config["devices"] else: ngpu = 1 - if 'accumulate_grad_batches' in lightning_config.trainer: + if "accumulate_grad_batches" in lightning_config.trainer: accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches else: accumulate_grad_batches = 1 @@ -814,8 +826,10 @@ def on_train_epoch_end(self, trainer, pl_module): if opt.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr rank_zero_info( - "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)" - .format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( + model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr + ) + ) else: model.learning_rate = base_lr rank_zero_info("++++ NOT USING LR SCALING ++++") @@ -832,9 +846,11 @@ def melk(*args, **kwargs): def divein(*args, **kwargs): if trainer.global_rank == 0: import pudb + pudb.set_trace() import signal + # Assign melk to SIGUSR1 signal and divein to SIGUSR2 signal signal.signal(signal.SIGUSR1, melk) signal.signal(signal.SIGUSR2, divein) diff --git a/examples/images/diffusion/scripts/download_first_stages.sh b/examples/images/diffusion/scripts/download_first_stages.sh index a8d79e99ccdf..50dab5de5b90 100755 --- a/examples/images/diffusion/scripts/download_first_stages.sh +++ b/examples/images/diffusion/scripts/download_first_stages.sh @@ -38,4 +38,4 @@ unzip -o model.zip cd ../vq-f16 unzip -o model.zip -cd ../.. \ No newline at end of file +cd ../.. diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py index 877538d4733d..4c386113dcc3 100644 --- a/examples/images/diffusion/scripts/img2img.py +++ b/examples/images/diffusion/scripts/img2img.py @@ -1,28 +1,30 @@ """make variations of input image""" -import argparse, os +import argparse +import os +from contextlib import nullcontext +from itertools import islice + +import numpy as np import PIL import torch -import numpy as np +from einops import rearrange, repeat from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange, repeat -from torchvision.utils import make_grid from torch import autocast -from contextlib import nullcontext +from torchvision.utils import make_grid +from tqdm import tqdm, trange + try: from lightning.pytorch import seed_everything except: from pytorch_lightning import seed_everything -from imwatermark import WatermarkEncoder - -from scripts.txt2img import put_watermark -from ldm.util import instantiate_from_config +from imwatermark import WatermarkEncoder from ldm.models.diffusion.ddim import DDIMSampler -from utils import replace_module, getModelSize +from ldm.util import instantiate_from_config +from scripts.txt2img import put_watermark +from utils import replace_module def chunk(it, size): @@ -58,7 +60,7 @@ def load_img(path): image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) - return 2. * image - 1. + return 2.0 * image - 1.0 def main(): @@ -69,22 +71,13 @@ def main(): type=str, nargs="?", default="a painting of a virus monster playing guitar", - help="the prompt to render" + help="the prompt to render", ) - parser.add_argument( - "--init-img", - type=str, - nargs="?", - help="path to the input image" - ) + parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image") parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/img2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples" ) parser.add_argument( @@ -96,7 +89,7 @@ def main(): parser.add_argument( "--fixed_code", - action='store_true', + action="store_true", help="if enabled, uses the same starting code across all samples ", ) @@ -177,11 +170,7 @@ def main(): help="the seed (for reproducible sampling)", ) parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast" ) parser.add_argument( "--use_int8", @@ -204,7 +193,7 @@ def main(): model = replace_module(model) # # to compute the model size # getModelSize(model) - + sampler = DDIMSampler(model) os.makedirs(opt.outdir, exist_ok=True) @@ -213,7 +202,7 @@ def main(): print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") wm = "SDV2" wm_encoder = WatermarkEncoder() - wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + wm_encoder.set_watermark("bytes", wm.encode("utf-8")) batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size @@ -235,12 +224,12 @@ def main(): assert os.path.isfile(opt.init_img) init_image = load_img(opt.init_img).to(device) - init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) - assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' + assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]" t_enc = int(opt.strength * opt.ddim_steps) print(f"target t_enc is {t_enc} steps") @@ -261,14 +250,19 @@ def main(): # encode (scaled latent) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device)) # decode it - samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, ) + samples = sampler.decode( + z_enc, + c, + t_enc, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + ) x_samples = model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") img = Image.fromarray(x_sample.astype(np.uint8)) img = put_watermark(img, wm_encoder) img.save(os.path.join(sample_path, f"{base_count:05}.png")) @@ -277,14 +271,14 @@ def main(): # additionally, save as grid grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = rearrange(grid, "n b c h w -> (n b) c h w") grid = make_grid(grid, nrow=n_rows) # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() grid = Image.fromarray(grid.astype(np.uint8)) grid = put_watermark(grid, wm_encoder) - grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid.save(os.path.join(outpath, f"grid-{grid_count:04}.png")) grid_count += 1 print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/images/diffusion/scripts/inpaint.py b/examples/images/diffusion/scripts/inpaint.py index d6e6387a9a3b..afffcf1685e6 100644 --- a/examples/images/diffusion/scripts/inpaint.py +++ b/examples/images/diffusion/scripts/inpaint.py @@ -1,32 +1,35 @@ -import argparse, os, sys, glob -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm +import argparse +import glob +import os + import numpy as np import torch -from main import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler +from main import instantiate_from_config +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm def make_batch(image, mask, device): image = np.array(Image.open(image).convert("RGB")) - image = image.astype(np.float32)/255.0 - image = image[None].transpose(0,3,1,2) + image = image.astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) mask = np.array(Image.open(mask).convert("L")) - mask = mask.astype(np.float32)/255.0 - mask = mask[None,None] + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) - masked_image = (1-mask)*image + masked_image = (1 - mask) * image batch = {"image": image, "mask": mask, "masked_image": masked_image} for k in batch: batch[k] = batch[k].to(device=device) - batch[k] = batch[k]*2.0-1.0 + batch[k] = batch[k] * 2.0 - 1.0 return batch @@ -58,8 +61,7 @@ def make_batch(image, mask, device): config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") model = instantiate_from_config(config.model) - model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], - strict=False) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) @@ -74,25 +76,19 @@ def make_batch(image, mask, device): # encode masked image and concat downsampled mask c = model.cond_stage_model.encode(batch["masked_image"]) - cc = torch.nn.functional.interpolate(batch["mask"], - size=c.shape[-2:]) + cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:]) c = torch.cat((c, cc), dim=1) - shape = (c.shape[1]-1,)+c.shape[2:] - samples_ddim, _ = sampler.sample(S=opt.steps, - conditioning=c, - batch_size=c.shape[0], - shape=shape, - verbose=False) + shape = (c.shape[1] - 1,) + c.shape[2:] + samples_ddim, _ = sampler.sample( + S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False + ) x_samples_ddim = model.decode_first_stage(samples_ddim) - image = torch.clamp((batch["image"]+1.0)/2.0, - min=0.0, max=1.0) - mask = torch.clamp((batch["mask"]+1.0)/2.0, - min=0.0, max=1.0) - predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, - min=0.0, max=1.0) + image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0) + mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - inpainted = (1-mask)*image+mask*predicted_image - inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + inpainted = (1 - mask) * image + mask * predicted_image + inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/examples/images/diffusion/scripts/knn2img.py b/examples/images/diffusion/scripts/knn2img.py index e6eaaecab53e..763811665bbc 100644 --- a/examples/images/diffusion/scripts/knn2img.py +++ b/examples/images/diffusion/scripts/knn2img.py @@ -1,22 +1,22 @@ -import argparse, os, sys, glob -import clip -import torch -import torch.nn as nn -import numpy as np -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange, repeat -from torchvision.utils import make_grid -import scann +import argparse +import glob +import os import time +from itertools import islice from multiprocessing import cpu_count -from ldm.util import instantiate_from_config, parallel_data_prefetch +import numpy as np +import scann +import torch +from einops import rearrange from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder +from ldm.util import instantiate_from_config, parallel_data_prefetch +from omegaconf import OmegaConf +from PIL import Image +from torchvision.utils import make_grid +from tqdm import tqdm, trange DATABASES = [ "openimages", @@ -59,29 +59,24 @@ def load_model_from_config(config, ckpt, verbose=False): class Searcher(object): - def __init__(self, database, retriever_version='ViT-L/14'): + def __init__(self, database, retriever_version="ViT-L/14"): assert database in DATABASES # self.database = self.load_database(database) self.database_name = database - self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' - self.database_path = f'data/rdm/retrieval_databases/{self.database_name}' + self.searcher_savedir = f"data/rdm/searchers/{self.database_name}" + self.database_path = f"data/rdm/retrieval_databases/{self.database_name}" self.retriever = self.load_retriever(version=retriever_version) - self.database = {'embedding': [], - 'img_id': [], - 'patch_coords': []} + self.database = {"embedding": [], "img_id": [], "patch_coords": []} self.load_database() self.load_searcher() - def train_searcher(self, k, - metric='dot_product', - searcher_savedir=None): - - print('Start training searcher') - searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / - np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], - k, metric) + def train_searcher(self, k, metric="dot_product", searcher_savedir=None): + print("Start training searcher") + searcher = scann.scann_ops_pybind.builder( + self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric + ) self.searcher = searcher.score_brute_force().build() - print('Finish training searcher') + print("Finish training searcher") if searcher_savedir is not None: print(f'Save trained searcher under "{searcher_savedir}"') @@ -91,36 +86,40 @@ def train_searcher(self, k, def load_single_file(self, saved_embeddings): compressed = np.load(saved_embeddings) self.database = {key: compressed[key] for key in compressed.files} - print('Finished loading of clip embeddings.') + print("Finished loading of clip embeddings.") def load_multi_files(self, data_archive): out_data = {key: [] for key in self.database} - for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."): for key in d.files: out_data[key].append(d[key]) return out_data def load_database(self): - print(f'Load saved patch embedding from "{self.database_path}"') - file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + file_content = glob.glob(os.path.join(self.database_path, "*.npz")) if len(file_content) == 1: self.load_single_file(file_content[0]) elif len(file_content) > 1: data = [np.load(f) for f in file_content] - prefetched_data = parallel_data_prefetch(self.load_multi_files, data, - n_proc=min(len(data), cpu_count()), target_data_type='dict') + prefetched_data = parallel_data_prefetch( + self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict" + ) - self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in - self.database} + self.database = { + key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database + } else: raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') - def load_retriever(self, version='ViT-L/14', ): + def load_retriever( + self, + version="ViT-L/14", + ): model = FrozenClipImageEmbedder(model=version) if torch.cuda.is_available(): model.cuda() @@ -128,14 +127,14 @@ def load_retriever(self, version='ViT-L/14', ): return model def load_searcher(self): - print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + print(f"load searcher for database {self.database_name} from {self.searcher_savedir}") self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) - print('Finished loading searcher.') + print("Finished loading searcher.") def search(self, x, k): - if self.searcher is None and self.database['embedding'].shape[0] < 2e4: - self.train_searcher(k) # quickly fit searcher on the fly for small databases - assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if self.searcher is None and self.database["embedding"].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, "Cannot search with uninitialized searcher" if isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() if len(x.shape) == 3: @@ -146,17 +145,19 @@ def search(self, x, k): nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) end = time.time() - out_embeddings = self.database['embedding'][nns] - out_img_ids = self.database['img_id'][nns] - out_pc = self.database['patch_coords'][nns] + out_embeddings = self.database["embedding"][nns] + out_img_ids = self.database["img_id"][nns] + out_pc = self.database["patch_coords"][nns] - out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], - 'img_ids': out_img_ids, - 'patch_coords': out_pc, - 'queries': x, - 'exec_time': end - start, - 'nns': nns, - 'q_embeddings': query_embeddings} + out = { + "nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + "img_ids": out_img_ids, + "patch_coords": out_pc, + "queries": x, + "exec_time": end - start, + "nns": nns, + "q_embeddings": query_embeddings, + } return out @@ -173,20 +174,16 @@ def __call__(self, x, n): type=str, nargs="?", default="a painting of a virus monster playing guitar", - help="the prompt to render" + help="the prompt to render", ) parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples" ) parser.add_argument( "--skip_grid", - action='store_true', + action="store_true", help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", ) @@ -206,7 +203,7 @@ def __call__(self, x, n): parser.add_argument( "--plms", - action='store_true', + action="store_true", help="use plms sampling", ) @@ -287,14 +284,14 @@ def __call__(self, x, n): parser.add_argument( "--database", type=str, - default='artbench-surrealism', + default="artbench-surrealism", choices=DATABASES, help="The database used for the search, only applied when --use_neighbors=True", ) parser.add_argument( "--use_neighbors", default=False, - action='store_true', + action="store_true", help="Include neighbors in addition to text prompt for conditioning", ) parser.add_argument( @@ -358,41 +355,43 @@ def __call__(self, x, n): uc = None if searcher is not None: nn_dict = searcher(c, opt.knn) - c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1) if opt.scale != 1.0: uc = torch.zeros_like(c) if isinstance(prompts, tuple): prompts = list(prompts) shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=c.shape[0], - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - ) + samples_ddim, _ = sampler.sample( + S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, f"{base_count:05}.png")) + os.path.join(sample_path, f"{base_count:05}.png") + ) base_count += 1 all_samples.append(x_samples_ddim) if not opt.skip_grid: # additionally, save as grid grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = rearrange(grid, "n b c h w -> (n b) c h w") grid = make_grid(grid, nrow=n_rows) # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png")) grid_count += 1 print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/examples/images/diffusion/scripts/sample_diffusion.py b/examples/images/diffusion/scripts/sample_diffusion.py index 876fe3c3642f..740aae2435d2 100644 --- a/examples/images/diffusion/scripts/sample_diffusion.py +++ b/examples/images/diffusion/scripts/sample_diffusion.py @@ -1,21 +1,26 @@ -import argparse, os, sys, glob, datetime, yaml -import torch +import argparse +import datetime +import glob +import os +import sys import time -import numpy as np -from tqdm import trange +import numpy as np +import torch +import yaml +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config from omegaconf import OmegaConf from PIL import Image +from tqdm import trange -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.util import instantiate_from_config +rescale = lambda x: (x + 1.0) / 2.0 -rescale = lambda x: (x + 1.) / 2. def custom_to_pil(x): x = x.detach().cpu() - x = torch.clamp(x, -1., 1.) - x = (x + 1.) / 2. + x = torch.clamp(x, -1.0, 1.0) + x = (x + 1.0) / 2.0 x = x.permute(1, 2, 0).numpy() x = (255 * x).astype(np.uint8) x = Image.fromarray(x) @@ -51,49 +56,51 @@ def logs2pil(logs, keys=["sample"]): @torch.no_grad() -def convsample(model, shape, return_intermediates=True, - verbose=True, - make_prog_row=False): - - +def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False): if not make_prog_row: - return model.p_sample_loop(None, shape, - return_intermediates=return_intermediates, verbose=verbose) + return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose) else: - return model.progressive_denoising( - None, shape, verbose=True - ) + return model.progressive_denoising(None, shape, verbose=True) @torch.no_grad() -def convsample_ddim(model, steps, shape, eta=1.0 - ): +def convsample_ddim(model, steps, shape, eta=1.0): ddim = DDIMSampler(model) bs = shape[0] shape = shape[1:] - samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + samples, intermediates = ddim.sample( + steps, + batch_size=bs, + shape=shape, + eta=eta, + verbose=False, + ) return samples, intermediates @torch.no_grad() -def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): - - +def make_convolutional_sample( + model, + batch_size, + vanilla=False, + custom_steps=None, + eta=1.0, +): log = dict() - shape = [batch_size, - model.model.diffusion_model.in_channels, - model.model.diffusion_model.image_size, - model.model.diffusion_model.image_size] + shape = [ + batch_size, + model.model.diffusion_model.in_channels, + model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size, + ] with model.ema_scope("Plotting"): t0 = time.time() if vanilla: - sample, progrow = convsample(model, shape, - make_prog_row=True) + sample, progrow = convsample(model, shape, make_prog_row=True) else: - sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, - eta=eta) + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta) t1 = time.time() @@ -101,32 +108,32 @@ def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=Non log["sample"] = x_sample log["time"] = t1 - t0 - log['throughput'] = sample.shape[0] / (t1 - t0) + log["throughput"] = sample.shape[0] / (t1 - t0) print(f'Throughput for this batch: {log["throughput"]}') return log + def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): if vanilla: - print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') + print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.") else: - print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') - + print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}") tstart = time.time() - n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + n_saved = len(glob.glob(os.path.join(logdir, "*.png"))) - 1 # path = logdir if model.cond_stage_model is None: all_images = [] print(f"Running unconditional sampling for {n_samples} samples") for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): - logs = make_convolutional_sample(model, batch_size=batch_size, - vanilla=vanilla, custom_steps=custom_steps, - eta=eta) + logs = make_convolutional_sample( + model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta + ) n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") all_images.extend([custom_to_np(logs["sample"])]) if n_saved >= n_samples: - print(f'Finish after generating {n_saved} samples') + print(f"Finish after generating {n_saved} samples") break all_img = np.concatenate(all_images, axis=0) all_img = all_img[:n_samples] @@ -135,7 +142,7 @@ def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None np.savez(nppath, all_img) else: - raise NotImplementedError('Currently only sampling for unconditional models supported.') + raise NotImplementedError("Currently only sampling for unconditional models supported.") print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") @@ -168,58 +175,33 @@ def get_parser(): nargs="?", help="load from logdir or checkpoint in logdir", ) - parser.add_argument( - "-n", - "--n_samples", - type=int, - nargs="?", - help="number of samples to draw", - default=50000 - ) + parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000) parser.add_argument( "-e", "--eta", type=float, nargs="?", help="eta for ddim sampling (0.0 yields deterministic sampling)", - default=1.0 + default=1.0, ) parser.add_argument( "-v", "--vanilla_sample", default=False, - action='store_true', + action="store_true", help="vanilla sampling (default option is DDIM sampling)?", ) + parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none") parser.add_argument( - "-l", - "--logdir", - type=str, - nargs="?", - help="extra logdir", - default="none" - ) - parser.add_argument( - "-c", - "--custom_steps", - type=int, - nargs="?", - help="number of steps for ddim and fastdpm sampling", - default=50 - ) - parser.add_argument( - "--batch_size", - type=int, - nargs="?", - help="the bs", - default=10 + "-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50 ) + parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10) return parser def load_model_from_config(config, sd): model = instantiate_from_config(config) - model.load_state_dict(sd,strict=False) + model.load_state_dict(sd, strict=False) model.cuda() model.eval() return model @@ -233,8 +215,7 @@ def load_model(config, ckpt, gpu, eval_mode): else: pl_sd = {"state_dict": None} global_step = None - model = load_model_from_config(config.model, - pl_sd["state_dict"]) + model = load_model_from_config(config.model, pl_sd["state_dict"]) return model, global_step @@ -253,9 +234,9 @@ def load_model(config, ckpt, gpu, eval_mode): if os.path.isfile(opt.resume): # paths = opt.resume.split("/") try: - logdir = '/'.join(opt.resume.split('/')[:-1]) + logdir = "/".join(opt.resume.split("/")[:-1]) # idx = len(paths)-paths[::-1].index("logs")+1 - print(f'Logdir is {logdir}') + print(f"Logdir is {logdir}") except ValueError: paths = opt.resume.split("/") idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt @@ -278,7 +259,8 @@ def load_model(config, ckpt, gpu, eval_mode): if opt.logdir != "none": locallog = logdir.split(os.sep)[-1] - if locallog == "": locallog = logdir.split(os.sep)[-2] + if locallog == "": + locallog = logdir.split(os.sep)[-2] print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") logdir = os.path.join(opt.logdir, locallog) @@ -301,13 +283,19 @@ def load_model(config, ckpt, gpu, eval_mode): sampling_file = os.path.join(logdir, "sampling_config.yaml") sampling_conf = vars(opt) - with open(sampling_file, 'w') as f: + with open(sampling_file, "w") as f: yaml.dump(sampling_conf, f, default_flow_style=False) print(sampling_conf) - - run(model, imglogdir, eta=opt.eta, - vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, - batch_size=opt.batch_size, nplog=numpylogdir) + run( + model, + imglogdir, + eta=opt.eta, + vanilla=opt.vanilla_sample, + n_samples=opt.n_samples, + custom_steps=opt.custom_steps, + batch_size=opt.batch_size, + nplog=numpylogdir, + ) print("done.") diff --git a/examples/images/diffusion/scripts/tests/test_checkpoint.py b/examples/images/diffusion/scripts/tests/test_checkpoint.py index 13622c4989fd..c0af17bdecaa 100644 --- a/examples/images/diffusion/scripts/tests/test_checkpoint.py +++ b/examples/images/diffusion/scripts/tests/test_checkpoint.py @@ -1,28 +1,18 @@ -import os -import sys -from copy import deepcopy - +import torch import yaml -from datetime import datetime - from diffusers import StableDiffusionPipeline -import torch - -from main import get_parser from ldm.modules.diffusionmodules.openaimodel import UNetModel if __name__ == "__main__": with torch.no_grad(): yaml_path = "../../train_colossalai.yaml" - with open(yaml_path, 'r', encoding='utf-8') as f: + with open(yaml_path, "r", encoding="utf-8") as f: config = f.read() base_config = yaml.load(config, Loader=yaml.FullLoader) - unet_config = base_config['model']['params']['unet_config'] + unet_config = base_config["model"]["params"]["unet_config"] diffusion_model = UNetModel(**unet_config).to("cuda:0") - pipe = StableDiffusionPipeline.from_pretrained( - "/data/scratch/diffuser/stable-diffusion-v1-4" - ).to("cuda:0") + pipe = StableDiffusionPipeline.from_pretrained("/data/scratch/diffuser/stable-diffusion-v1-4").to("cuda:0") dif_model_2 = pipe.unet random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0") @@ -35,4 +25,4 @@ out_1 = diffusion_model(random_input_, time_stamp, context_) out_2 = dif_model_2(random_input_2, time_stamp2, context_2) print(out_1.shape) - print(out_2['sample'].shape) \ No newline at end of file + print(out_2["sample"].shape) diff --git a/examples/images/diffusion/scripts/tests/test_watermark.py b/examples/images/diffusion/scripts/tests/test_watermark.py index f93f8a6e7076..9bfc9fc7d9cb 100644 --- a/examples/images/diffusion/scripts/tests/test_watermark.py +++ b/examples/images/diffusion/scripts/tests/test_watermark.py @@ -5,14 +5,14 @@ def testit(img_path): bgr = cv2.imread(img_path) - decoder = WatermarkDecoder('bytes', 136) - watermark = decoder.decode(bgr, 'dwtDct') + decoder = WatermarkDecoder("bytes", 136) + watermark = decoder.decode(bgr, "dwtDct") try: - dec = watermark.decode('utf-8') + dec = watermark.decode("utf-8") except: dec = "null" print(dec) if __name__ == "__main__": - fire.Fire(testit) \ No newline at end of file + fire.Fire(testit) diff --git a/examples/images/diffusion/scripts/train_searcher.py b/examples/images/diffusion/scripts/train_searcher.py index 1e7904889c01..1df0baa7e5cf 100644 --- a/examples/images/diffusion/scripts/train_searcher.py +++ b/examples/images/diffusion/scripts/train_searcher.py @@ -1,33 +1,39 @@ -import os, sys -import numpy as np -import scann import argparse import glob +import os +import sys from multiprocessing import cpu_count -from tqdm import tqdm +import numpy as np +import scann from ldm.util import parallel_data_prefetch +from tqdm import tqdm def search_bruteforce(searcher): return searcher.score_brute_force().build() -def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, - partioning_trainsize, num_leaves, num_leaves_to_search): - return searcher.tree(num_leaves=num_leaves, - num_leaves_to_search=num_leaves_to_search, - training_sample_size=partioning_trainsize). \ - score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() +def search_partioned_ah( + searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search +): + return ( + searcher.tree( + num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=partioning_trainsize + ) + .score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold) + .reorder(reorder_k) + .build() + ) def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): - return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( - reorder_k).build() - -def load_datapool(dpath): + return ( + searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + ) +def load_datapool(dpath): def load_single_file(saved_embeddings): compressed = np.load(saved_embeddings) database = {key: compressed[key] for key in compressed.files} @@ -35,23 +41,26 @@ def load_single_file(saved_embeddings): def load_multi_files(data_archive): database = {key: [] for key in data_archive[0].files} - for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."): for key in d.files: database[key].append(d[key]) return database print(f'Load saved patch embedding from "{dpath}"') - file_content = glob.glob(os.path.join(dpath, '*.npz')) + file_content = glob.glob(os.path.join(dpath, "*.npz")) if len(file_content) == 1: data_pool = load_single_file(file_content[0]) elif len(file_content) > 1: data = [np.load(f) for f in file_content] - prefetched_data = parallel_data_prefetch(load_multi_files, data, - n_proc=min(len(data), cpu_count()), target_data_type='dict') + prefetched_data = parallel_data_prefetch( + load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict" + ) - data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + data_pool = { + key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys() + } else: raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') @@ -59,16 +68,17 @@ def load_multi_files(data_archive): return data_pool -def train_searcher(opt, - metric='dot_product', - partioning_trainsize=None, - reorder_k=None, - # todo tune - aiq_thld=0.2, - dims_per_block=2, - num_leaves=None, - num_leaves_to_search=None,): - +def train_searcher( + opt, + metric="dot_product", + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None, +): data_pool = load_datapool(opt.database) k = opt.knn @@ -77,71 +87,83 @@ def train_searcher(opt, # normalize # embeddings = - searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) - pool_size = data_pool['embedding'].shape[0] - - print(*(['#'] * 100)) - print('Initializing scaNN searcher with the following values:') - print(f'k: {k}') - print(f'metric: {metric}') - print(f'reorder_k: {reorder_k}') - print(f'anisotropic_quantization_threshold: {aiq_thld}') - print(f'dims_per_block: {dims_per_block}') - print(*(['#'] * 100)) - print('Start training searcher....') - print(f'N samples in pool is {pool_size}') + searcher = scann.scann_ops_pybind.builder( + data_pool["embedding"] / np.linalg.norm(data_pool["embedding"], axis=1)[:, np.newaxis], k, metric + ) + pool_size = data_pool["embedding"].shape[0] + + print(*(["#"] * 100)) + print("Initializing scaNN searcher with the following values:") + print(f"k: {k}") + print(f"metric: {metric}") + print(f"reorder_k: {reorder_k}") + print(f"anisotropic_quantization_threshold: {aiq_thld}") + print(f"dims_per_block: {dims_per_block}") + print(*(["#"] * 100)) + print("Start training searcher....") + print(f"N samples in pool is {pool_size}") # this reflects the recommended design choices proposed at # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md if pool_size < 2e4: - print('Using brute force search.') + print("Using brute force search.") searcher = search_bruteforce(searcher) elif 2e4 <= pool_size and pool_size < 1e5: - print('Using asymmetric hashing search and reordering.') + print("Using asymmetric hashing search and reordering.") searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) else: - print('Using using partioning, asymmetric hashing search and reordering.') + print("Using using partioning, asymmetric hashing search and reordering.") if not partioning_trainsize: - partioning_trainsize = data_pool['embedding'].shape[0] // 10 + partioning_trainsize = data_pool["embedding"].shape[0] // 10 if not num_leaves: num_leaves = int(np.sqrt(pool_size)) if not num_leaves_to_search: num_leaves_to_search = max(num_leaves // 20, 1) - print('Partitioning params:') - print(f'num_leaves: {num_leaves}') - print(f'num_leaves_to_search: {num_leaves_to_search}') + print("Partitioning params:") + print(f"num_leaves: {num_leaves}") + print(f"num_leaves_to_search: {num_leaves_to_search}") # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) - searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, - partioning_trainsize, num_leaves, num_leaves_to_search) + searcher = search_partioned_ah( + searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search + ) - print('Finish training searcher') + print("Finish training searcher") searcher_savedir = opt.target_path os.makedirs(searcher_savedir, exist_ok=True) searcher.serialize(searcher_savedir) print(f'Saved trained searcher under "{searcher_savedir}"') -if __name__ == '__main__': + +if __name__ == "__main__": sys.path.append(os.getcwd()) parser = argparse.ArgumentParser() - parser.add_argument('--database', - '-d', - default='data/rdm/retrieval_databases/openimages', - type=str, - help='path to folder containing the clip feature of the database') - parser.add_argument('--target_path', - '-t', - default='data/rdm/searchers/openimages', - type=str, - help='path to the target folder where the searcher shall be stored.') - parser.add_argument('--knn', - '-k', - default=20, - type=int, - help='number of nearest neighbors, for which the searcher shall be optimized') - - opt, _ = parser.parse_known_args() - - train_searcher(opt,) \ No newline at end of file + parser.add_argument( + "--database", + "-d", + default="data/rdm/retrieval_databases/openimages", + type=str, + help="path to folder containing the clip feature of the database", + ) + parser.add_argument( + "--target_path", + "-t", + default="data/rdm/searchers/openimages", + type=str, + help="path to the target folder where the searcher shall be stored.", + ) + parser.add_argument( + "--knn", + "-k", + default=20, + type=int, + help="number of nearest neighbors, for which the searcher shall be optimized", + ) + + opt, _ = parser.parse_known_args() + + train_searcher( + opt, + ) diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py index 364ebac6c67b..feb17b9f77ae 100644 --- a/examples/images/diffusion/scripts/txt2img.py +++ b/examples/images/diffusion/scripts/txt2img.py @@ -1,29 +1,34 @@ -import argparse, os +import argparse +import os +from itertools import islice + import cv2 -import torch import numpy as np +import torch +from einops import rearrange from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange from torchvision.utils import make_grid +from tqdm import tqdm, trange + try: from lightning.pytorch import seed_everything except: from pytorch_lightning import seed_everything -from torch import autocast + from contextlib import nullcontext -from imwatermark import WatermarkEncoder -from ldm.util import instantiate_from_config +from imwatermark import WatermarkEncoder from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from utils import replace_module, getModelSize +from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config +from torch import autocast +from utils import replace_module torch.set_grad_enabled(False) + def chunk(it, size): it = iter(it) return iter(lambda: tuple(islice(it, size)), ()) @@ -55,14 +60,10 @@ def parse_args(): type=str, nargs="?", default="a professional photograph of an astronaut riding a triceratops", - help="the prompt to render" + help="the prompt to render", ) parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples" ) parser.add_argument( "--steps", @@ -72,17 +73,17 @@ def parse_args(): ) parser.add_argument( "--plms", - action='store_true', + action="store_true", help="use plms sampling", ) parser.add_argument( "--dpm", - action='store_true', + action="store_true", help="use DPM (2) sampler", ) parser.add_argument( "--fixed_code", - action='store_true', + action="store_true", help="if enabled, uses the same starting code across all samples ", ) parser.add_argument( @@ -162,11 +163,7 @@ def parse_args(): help="the seed (for reproducible sampling)", ) parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast" ) parser.add_argument( "--repeat", @@ -187,7 +184,7 @@ def parse_args(): def put_watermark(img, wm_encoder=None): if wm_encoder is not None: img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) - img = wm_encoder.encode(img, 'dwtDct') + img = wm_encoder.encode(img, "dwtDct") img = Image.fromarray(img[:, :, ::-1]) return img @@ -197,17 +194,17 @@ def main(opt): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") - + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) - + # quantize model if opt.use_int8: model = replace_module(model) # # to compute the model size # getModelSize(model) - + if opt.plms: sampler = PLMSSampler(model) elif opt.dpm: @@ -221,7 +218,7 @@ def main(opt): print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") wm = "SDV2" wm_encoder = WatermarkEncoder() - wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + wm_encoder.set_watermark("bytes", wm.encode("utf-8")) batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size @@ -248,56 +245,55 @@ def main(opt): start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) precision_scope = autocast if opt.precision == "autocast" else nullcontext - with torch.no_grad(), \ - precision_scope("cuda"), \ - model.ema_scope(): - all_samples = list() - for n in trange(opt.n_iter, desc="Sampling"): - for prompts in tqdm(data, desc="data"): - uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples, _ = sampler.sample(S=opt.steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code) - - x_samples = model.decode_first_stage(samples) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - - for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - img = Image.fromarray(x_sample.astype(np.uint8)) - img = put_watermark(img, wm_encoder) - img.save(os.path.join(sample_path, f"{base_count:05}.png")) - base_count += 1 - sample_count += 1 - - all_samples.append(x_samples) - - # additionally, save as grid - grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=n_rows) - - # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - grid = Image.fromarray(grid.astype(np.uint8)) - grid = put_watermark(grid, wm_encoder) - grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - grid_count += 1 - - print(f"Your samples are ready and waiting for you here: \n{outpath} \n" - f" \nEnjoy.") + with torch.no_grad(), precision_scope("cuda"), model.ema_scope(): + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples, _ = sampler.sample( + S=opt.steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code, + ) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + sample_count += 1 + + all_samples.append(x_samples) + + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, "n b c h w -> (n b) c h w") + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f"grid-{grid_count:04}.png")) + grid_count += 1 + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") if __name__ == "__main__": diff --git a/examples/images/diffusion/scripts/utils.py b/examples/images/diffusion/scripts/utils.py index c954b22ca190..92ed0b4dfd0a 100644 --- a/examples/images/diffusion/scripts/utils.py +++ b/examples/images/diffusion/scripts/utils.py @@ -1,6 +1,7 @@ import bitsandbytes as bnb -import torch.nn as nn import torch +import torch.nn as nn + class Linear8bit(nn.Linear): def __init__( @@ -12,11 +13,9 @@ def __init__( memory_efficient_backward=False, threshold=6.0, weight_data=None, - bias_data=None + bias_data=None, ): - super(Linear8bit, self).__init__( - input_features, output_features, bias - ) + super(Linear8bit, self).__init__(input_features, output_features, bias) self.state = bnb.MatmulLtState() self.bias = bias_data self.state.threshold = threshold @@ -24,13 +23,12 @@ def __init__( self.state.memory_efficient_backward = memory_efficient_backward if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - + self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False)) self.weight = weight_data self.quant() - - def quant(self): + def quant(self): weight = self.weight.data.contiguous().half().cuda() CB, _, SCB, _, _ = bnb.functional.double_quant(weight) delattr(self, "weight") @@ -41,32 +39,34 @@ def quant(self): def forward(self, x): self.state.is_training = self.training - + if self.bias is not None and self.bias.dtype != torch.float16: self.bias.data = self.bias.data.half() - + self.state.CB = self.weight.data self.state.SCB = self.SCB.data - + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) del self.state.CxB return out + def replace_module(model): for name, module in model.named_children(): if len(list(module.children())) > 0: replace_module(module) - if isinstance(module, nn.Linear) and "out_proj" not in name: + if isinstance(module, nn.Linear) and "out_proj" not in name: model._modules[name] = Linear8bit( - input_features=module.in_features, - output_features=module.out_features, - threshold=6.0, - weight_data=module.weight, - bias_data=module.bias, - ) + input_features=module.in_features, + output_features=module.out_features, + threshold=6.0, + weight_data=module.weight, + bias_data=module.bias, + ) return model + def getModelSize(model): param_size = 0 param_sum = 0 @@ -79,5 +79,5 @@ def getModelSize(model): buffer_size += buffer.nelement() * buffer.element_size() buffer_sum += buffer.nelement() all_size = (param_size + buffer_size) / 1024 / 1024 - print('Model Size: {:.3f}MB'.format(all_size)) + print("Model Size: {:.3f}MB".format(all_size)) return (param_size, param_sum, buffer_size, buffer_sum, all_size) diff --git a/examples/images/diffusion/setup.py b/examples/images/diffusion/setup.py index a24d54167640..13d9f8927801 100644 --- a/examples/images/diffusion/setup.py +++ b/examples/images/diffusion/setup.py @@ -1,13 +1,13 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( - name='latent-diffusion', - version='0.0.1', - description='', + name="latent-diffusion", + version="0.0.1", + description="", packages=find_packages(), install_requires=[ - 'torch', - 'numpy', - 'tqdm', + "torch", + "numpy", + "tqdm", ], -) \ No newline at end of file +) diff --git a/examples/images/diffusion/train_colossalai.sh b/examples/images/diffusion/train_colossalai.sh index 7f1a1bd14615..c56ed7876e5a 100755 --- a/examples/images/diffusion/train_colossalai.sh +++ b/examples/images/diffusion/train_colossalai.sh @@ -3,4 +3,3 @@ TRANSFORMERS_OFFLINE=1 DIFFUSERS_OFFLINE=1 python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt - diff --git a/examples/images/diffusion/train_ddp.sh b/examples/images/diffusion/train_ddp.sh index 78fe765488c6..8304d6fa8b4f 100644 --- a/examples/images/diffusion/train_ddp.sh +++ b/examples/images/diffusion/train_ddp.sh @@ -1,5 +1,5 @@ -HF_DATASETS_OFFLINE=1 -TRANSFORMERS_OFFLINE=1 -DIFFUSERS_OFFLINE=1 +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 python main.py --logdir /tmp -t -b /configs/train_ddp.yaml diff --git a/examples/images/dreambooth/README.md b/examples/images/dreambooth/README.md index ba4c1a71034a..4e9febbc5fa8 100644 --- a/examples/images/dreambooth/README.md +++ b/examples/images/dreambooth/README.md @@ -93,7 +93,7 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \ ``` ## New API -We have modified our previous implementation of Dreambooth with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in `train_dreambooth_colossalai.py`. +We have modified our previous implementation of Dreambooth with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in `train_dreambooth_colossalai.py`. We have also offer a shell script `test_ci.sh` for you to go through all our plugins for the booster. For more information about the booster API you can refer to https://colossalai.org/docs/basics/booster_api/. @@ -111,7 +111,7 @@ For more information about the booster API you can refer to https://colossalai.o | low_level_zero | 4 | 8 | 28.87 | 2.02 | The evaluation is performed on 4 Nvidia A100 GPUs with 80GB memory each, with GPU 0 & 1, 2 & 3 connected with NVLink. -We finetuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model with 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared +We finetuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model with 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared the memory cost and the throughput for the plugins. diff --git a/examples/images/dreambooth/debug.py b/examples/images/dreambooth/debug.py index 33219b2caa29..8ce4dc3bbd80 100644 --- a/examples/images/dreambooth/debug.py +++ b/examples/images/dreambooth/debug.py @@ -1,16 +1,16 @@ -''' +""" torchrun --standalone --nproc_per_node=1 debug.py -''' +""" from diffusers import AutoencoderKL import colossalai -from colossalai.zero import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext path = "/data/scratch/diffuser/stable-diffusion-v1-4" colossalai.launch_from_torch(config={}) -with ColoInitContext(device='cpu'): +with ColoInitContext(device="cpu"): vae = AutoencoderKL.from_pretrained( path, subfolder="vae", diff --git a/examples/images/dreambooth/inference.py b/examples/images/dreambooth/inference.py index c342821c7830..ff317827aff7 100644 --- a/examples/images/dreambooth/inference.py +++ b/examples/images/dreambooth/inference.py @@ -1,7 +1,7 @@ -from diffusers import StableDiffusionPipeline, DiffusionPipeline import torch +from diffusers import DiffusionPipeline -model_id = +model_id = "" print(f"Loading model... from{model_id}") pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") diff --git a/examples/images/dreambooth/train_dreambooth.py b/examples/images/dreambooth/train_dreambooth.py index b989955f7fb7..9b66089b2752 100644 --- a/examples/images/dreambooth/train_dreambooth.py +++ b/examples/images/dreambooth/train_dreambooth.py @@ -104,8 +104,10 @@ def parse_args(input_args=None): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -118,17 +120,18 @@ def parse_args(input_args=None): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" ) - parser.add_argument("--center_crop", - action="store_true", - help="Whether to center crop images before resizing to resolution") parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -165,16 +168,17 @@ def parse_args(input_args=None): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), - ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -192,8 +196,10 @@ def parse_args(input_args=None): "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", @@ -203,7 +209,8 @@ def parse_args(input_args=None): help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -269,12 +276,14 @@ def __init__( else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -350,7 +359,8 @@ def main(args): if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.") + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) if args.seed is not None: set_seed(args.seed) @@ -380,9 +390,9 @@ def main(args): sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - for example in tqdm(sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process): + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): images = pipeline(example["prompt"]).images for i, image in enumerate(images): @@ -456,8 +466,9 @@ def main(args): text_encoder.gradient_checkpointing_enable() if args.scale_lr: - args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * - accelerator.num_processes) + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: @@ -470,8 +481,9 @@ def main(args): else: optimizer_class = torch.optim.AdamW - params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -506,9 +518,7 @@ def collate_fn(examples): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( - { - "input_ids": input_ids - }, + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", @@ -520,11 +530,9 @@ def collate_fn(examples): } return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=1) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -542,10 +550,12 @@ def collate_fn(examples): if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, - lr_scheduler) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": @@ -641,8 +651,11 @@ def collate_fn(examples): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 9b2ed3b971ae..1a7f8da7f7d0 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -117,8 +117,10 @@ def parse_args(input_args=None): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -131,8 +133,10 @@ def parse_args(input_args=None): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), ) parser.add_argument( "--offload_optim_frac", @@ -144,13 +148,14 @@ def parse_args(input_args=None): "--center_crop", default=False, action="store_true", - help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping."), + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -181,16 +186,17 @@ def parse_args(input_args=None): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -202,18 +208,22 @@ def parse_args(input_args=None): default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], - help="plugin to use") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) parser.add_argument( "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", @@ -223,7 +233,8 @@ def parse_args(input_args=None): help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -292,12 +303,14 @@ def __init__( else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -391,9 +404,9 @@ def main(args): pipeline.to(get_current_device()) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not local_rank == 0, + sample_dataloader, + desc="Generating class images", + disable=not local_rank == 0, ): images = pipeline(example["prompt"]).images @@ -460,15 +473,14 @@ def main(args): if args.externel_unet_path is None: logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - low_cpu_mem_usage=False) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False + ) else: logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) - unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, - revision=args.revision, - low_cpu_mem_usage=False) + unet = UNet2DConditionModel.from_pretrained( + args.externel_unet_path, revision=args.revision, low_cpu_mem_usage=False + ) vae.requires_grad_(False) text_encoder.requires_grad_(False) @@ -482,36 +494,37 @@ def main(args): # Use Booster API to use Gemini/Zero with ColossalAI booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) # config optimizer for colossalai zero - optimizer = HybridAdam(unet.parameters(), - lr=args.learning_rate, - initial_scale=2**5, - clipping_norm=args.max_grad_norm) + optimizer = HybridAdam( + unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm + ) # load noise_scheduler noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # prepare dataset logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0]) - train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir, - instance_prompt=args.instance_prompt, - class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_prompt=args.class_prompt, - tokenizer=tokenizer, - size=args.resolution, - center_crop=args.center_crop, - test=args.test_run) + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + test=args.test_run, + ) def collate_fn(examples): input_ids = [example["instance_prompt_ids"] for example in examples] @@ -527,9 +540,7 @@ def collate_fn(examples): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( - { - "input_ids": input_ids - }, + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", @@ -541,11 +552,9 @@ def collate_fn(examples): } return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=1) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -664,7 +673,7 @@ def collate_fn(examples): logs = { "loss": loss.detach().item(), "lr": optimizer.param_groups[0]["lr"], - } # lr_scheduler.get_last_lr()[0]} + } # lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step % args.save_steps == 0: diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index 654bce36ccb7..ea6dde8bb578 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -28,8 +28,6 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer -from colossalai.zero.gemini import get_static_torch_model disable_existing_loggers() logger = get_dist_logger() @@ -122,8 +120,10 @@ def parse_args(input_args=None): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If there are not enough images already present in" - " class_data_dir, additional images will be sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -136,8 +136,10 @@ def parse_args(input_args=None): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), ) parser.add_argument( "--placement", @@ -149,13 +151,14 @@ def parse_args(input_args=None): "--center_crop", default=False, action="store_true", - help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly" - " cropped. The images will be resized to the resolution first before cropping."), + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -186,16 +189,17 @@ def parse_args(input_args=None): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -206,18 +210,22 @@ def parse_args(input_args=None): default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], - help="plugin to use") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) parser.add_argument( "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", @@ -227,7 +235,8 @@ def parse_args(input_args=None): help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" - " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."), + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -293,12 +302,14 @@ def __init__( else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -392,9 +403,9 @@ def main(args): pipeline.to(get_current_device()) for example in tqdm( - sample_dataloader, - desc="Generating class images", - disable=not local_rank == 0, + sample_dataloader, + desc="Generating class images", + disable=not local_rank == 0, ): images = pipeline(example["prompt"]).images @@ -461,19 +472,17 @@ def main(args): if args.externel_unet_path is None: logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - low_cpu_mem_usage=False) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False + ) else: logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) - unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, - revision=args.revision, - low_cpu_mem_usage=False) - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - low_cpu_mem_usage=False) + unet = UNet2DConditionModel.from_pretrained( + args.externel_unet_path, revision=args.revision, low_cpu_mem_usage=False + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False + ) unet.requires_grad_(False) # Set correct lora layers @@ -492,7 +501,7 @@ def main(args): lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) unet.set_attn_processor(lora_attn_procs) - lora_layers = AttnProcsLayers(unet.attn_processors) + AttnProcsLayers(unet.attn_processors) vae.requires_grad_(False) text_encoder.requires_grad_(False) @@ -506,22 +515,21 @@ def main(args): # Use Booster API to use Gemini/Zero with ColossalAI booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) # config optimizer for colossalai zero - optimizer = HybridAdam(unet.parameters(), - lr=args.learning_rate, - initial_scale=2**5, - clipping_norm=args.max_grad_norm) + optimizer = HybridAdam( + unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm + ) # load noise_scheduler noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") @@ -552,9 +560,7 @@ def collate_fn(examples): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad( - { - "input_ids": input_ids - }, + {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt", @@ -566,11 +572,9 @@ def collate_fn(examples): } return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=1) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -689,7 +693,7 @@ def collate_fn(examples): logs = { "loss": loss.detach().item(), "lr": optimizer.param_groups[0]["lr"], - } # lr_scheduler.get_last_lr()[0]} + } # lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step % args.save_steps == 0: diff --git a/examples/images/dreambooth/train_dreambooth_inpaint.py b/examples/images/dreambooth/train_dreambooth_inpaint.py index 774cd4c458e9..32f1b4959879 100644 --- a/examples/images/dreambooth/train_dreambooth_inpaint.py +++ b/examples/images/dreambooth/train_dreambooth_inpaint.py @@ -126,8 +126,10 @@ def parse_args(): "--num_class_images", type=int, default=100, - help=("Minimal class images for prior preservation loss. If not have enough images, additional images will be" - " sampled with class_prompt."), + help=( + "Minimal class images for prior preservation loss. If not have enough images, additional images will be" + " sampled with class_prompt." + ), ) parser.add_argument( "--output_dir", @@ -140,17 +142,18 @@ def parse_args(): "--resolution", type=int, default=512, - help=("The resolution for input images, all the images in the train/validation dataset will be resized to this" - " resolution"), + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" ) - parser.add_argument("--center_crop", - action="store_true", - help="Whether to center crop images before resizing to resolution") parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") - parser.add_argument("--train_batch_size", - type=int, - default=4, - help="Batch size (per device) for the training dataloader.") + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.") parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( @@ -186,16 +189,17 @@ def parse_args(): "--lr_scheduler", type=str, default="constant", - help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]'), + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) - parser.add_argument("--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.") - parser.add_argument("--use_8bit_adam", - action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes.") parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") @@ -213,17 +217,21 @@ def parse_args(): "--logging_dir", type=str, default="logs", - help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" - " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."), + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), ) parser.add_argument( "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], - help=("Whether to use mixed precision. Choose" - "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU."), + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -283,12 +291,14 @@ def __init__( else: self.class_data_root = None - self.image_transforms = transforms.Compose([ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ]) + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) def __len__(self): return self._length @@ -369,7 +379,8 @@ def main(): if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: raise ValueError( "Gradient accumulation is not supported when training the text encoder in distributed training. " - "Please set gradient_accumulation_steps to 1. This feature will be supported in the future.") + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) if args.seed is not None: set_seed(args.seed) @@ -382,25 +393,25 @@ def main(): if cur_class_images < args.num_class_images: torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 - pipeline = StableDiffusionInpaintPipeline.from_pretrained(args.pretrained_model_name_or_path, - torch_dtype=torch_dtype, - safety_checker=None) + pipeline = StableDiffusionInpaintPipeline.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None + ) pipeline.set_progress_bar_config(disable=True) num_new_images = args.num_class_images - cur_class_images logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, - batch_size=args.sample_batch_size, - num_workers=1) + sample_dataloader = torch.utils.data.DataLoader( + sample_dataset, batch_size=args.sample_batch_size, num_workers=1 + ) sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) transform_to_pil = transforms.ToPILImage() - for example in tqdm(sample_dataloader, - desc="Generating class images", - disable=not accelerator.is_local_main_process): + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): bsz = len(example["prompt"]) fake_images = torch.rand((3, args.resolution, args.resolution)) transform_to_pil = transforms.ToPILImage() @@ -457,8 +468,9 @@ def main(): text_encoder.gradient_checkpointing_enable() if args.scale_lr: - args.learning_rate = (args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * - accelerator.num_processes) + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: @@ -471,8 +483,9 @@ def main(): else: optimizer_class = torch.optim.AdamW - params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, @@ -494,10 +507,12 @@ def main(): ) def collate_fn(examples): - image_transforms = transforms.Compose([ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), - ]) + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + ] + ) input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] @@ -545,10 +560,9 @@ def collate_fn(examples): batch = {"input_ids": input_ids, "pixel_values": pixel_values, "masks": masks, "masked_images": masked_images} return batch - train_dataloader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=collate_fn) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn + ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -566,10 +580,12 @@ def collate_fn(examples): if args.train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, - lr_scheduler) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) weight_dtype = torch.float32 if args.mixed_precision == "fp16": @@ -622,16 +638,19 @@ def collate_fn(examples): latents = latents * 0.18215 # Convert masked images to latent space - masked_latents = vae.encode(batch["masked_images"].reshape( - batch["pixel_values"].shape).to(dtype=weight_dtype)).latent_dist.sample() + masked_latents = vae.encode( + batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype) + ).latent_dist.sample() masked_latents = masked_latents * 0.18215 masks = batch["masks"] # resize the mask to latents shape as we concatenate the mask to the latents - mask = torch.stack([ - torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) - for mask in masks - ]) + mask = torch.stack( + [ + torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) + for mask in masks + ] + ) mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) # Sample noise that we'll add to the latents @@ -680,8 +699,11 @@ def collate_fn(examples): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()) - if args.train_text_encoder else unet.parameters()) + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() diff --git a/examples/images/resnet/eval.py b/examples/images/resnet/eval.py index 657708ec3ff2..526e41a2850f 100644 --- a/examples/images/resnet/eval.py +++ b/examples/images/resnet/eval.py @@ -1,7 +1,6 @@ import argparse import torch -import torch.nn as nn import torchvision import torchvision.transforms as transforms @@ -9,15 +8,15 @@ # Parse Arguments # ============================== parser = argparse.ArgumentParser() -parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint") -parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") +parser.add_argument("-e", "--epoch", type=int, default=80, help="resume from the epoch's checkpoint") +parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") args = parser.parse_args() # ============================== # Prepare Test Dataset # ============================== # CIFAR-10 dataset -test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor()) +test_dataset = torchvision.datasets.CIFAR10(root="./data/", train=False, transform=transforms.ToTensor()) # Data loader test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False) @@ -26,7 +25,7 @@ # Load Model # ============================== model = torchvision.models.resnet18(num_classes=10).cuda() -state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth') +state_dict = torch.load(f"{args.checkpoint}/model_{args.epoch}.pth") model.load_state_dict(state_dict) # ============================== @@ -45,4 +44,4 @@ total += labels.size(0) correct += (predicted == labels).sum().item() - print('Accuracy of the model on the test images: {} %'.format(100 * correct / total)) + print("Accuracy of the model on the test images: {} %".format(100 * correct / total)) diff --git a/examples/images/resnet/requirements.txt b/examples/images/resnet/requirements.txt index 3c7da7743702..46b7da7d4870 100644 --- a/examples/images/resnet/requirements.txt +++ b/examples/images/resnet/requirements.txt @@ -2,4 +2,4 @@ colossalai torch torchvision tqdm -pytest \ No newline at end of file +pytest diff --git a/examples/images/resnet/train.py b/examples/images/resnet/train.py index fa300395c9f3..13df516d4189 100644 --- a/examples/images/resnet/train.py +++ b/examples/images/resnet/train.py @@ -30,23 +30,19 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): # transform transform_train = transforms.Compose( - [transforms.Pad(4), - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32), - transforms.ToTensor()]) + [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()] + ) transform_test = transforms.ToTensor() # CIFAR-10 dataset - data_path = os.environ.get('DATA', './data') + data_path = os.environ.get("DATA", "./data") with coordinator.priority_execution(): - train_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=True, - transform=transform_train, - download=True) - test_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=False, - transform=transform_test, - download=True) + train_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=True, transform=transform_train, download=True + ) + test_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=False, transform=transform_test, download=True + ) # Data loader train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) @@ -70,14 +66,21 @@ def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoo dist.all_reduce(total) accuracy = correct.item() / total.item() if coordinator.is_master(): - print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %') + print(f"Accuracy of the model on the test images: {accuracy * 100:.2f} %") return accuracy -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: nn.Module, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: for images, labels in pbar: images = images.cuda() labels = labels.cuda() @@ -91,7 +94,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: n optimizer.zero_grad() # Print log info - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) def main(): @@ -100,19 +103,20 @@ def main(): # ============================== parser = argparse.ArgumentParser() # FIXME(ver217): gemini is not supported resnet now - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero', 'gemini'], - help="plugin to use") - parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") - parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") - parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") - parser.add_argument('--target_acc', - type=float, - default=None, - help="target accuracy. Raise exception if not reached") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "low_level_zero", "gemini"], + help="plugin to use", + ) + parser.add_argument("-r", "--resume", type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") + parser.add_argument("-i", "--interval", type=int, default=5, help="interval of saving checkpoint") + parser.add_argument( + "--target_acc", type=float, default=None, help="target accuracy. Raise exception if not reached" + ) args = parser.parse_args() # ============================== @@ -136,13 +140,13 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) @@ -168,18 +172,17 @@ def main(): # ============================== # Boost with ColossalAI # ============================== - model, optimizer, criterion, _, lr_scheduler = booster.boost(model, - optimizer, - criterion=criterion, - lr_scheduler=lr_scheduler) + model, optimizer, criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=criterion, lr_scheduler=lr_scheduler + ) # ============================== # Resume from checkpoint # ============================== if args.resume >= 0: - booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') - booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') - booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') + booster.load_model(model, f"{args.checkpoint}/model_{args.resume}.pth") + booster.load_optimizer(optimizer, f"{args.checkpoint}/optimizer_{args.resume}.pth") + booster.load_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{args.resume}.pth") # ============================== # Train model @@ -191,14 +194,14 @@ def main(): # save checkpoint if args.interval > 0 and (epoch + 1) % args.interval == 0: - booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') - booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') - booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') + booster.save_model(model, f"{args.checkpoint}/model_{epoch + 1}.pth") + booster.save_optimizer(optimizer, f"{args.checkpoint}/optimizer_{epoch + 1}.pth") + booster.save_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{epoch + 1}.pth") accuracy = evaluate(model, test_dataloader, coordinator) if args.target_acc is not None: - assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}' + assert accuracy >= args.target_acc, f"Accuracy {accuracy} is lower than target accuracy {args.target_acc}" -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py index e6c52c4e97fd..7d54020f85c4 100644 --- a/examples/images/vit/args.py +++ b/examples/images/vit/args.py @@ -2,44 +2,47 @@ def parse_demo_args(): - parser = get_default_parser() - parser.add_argument("--model_name_or_path", - type=str, - default="google/vit-base-patch16-224", - help="Path to pretrained model or model identifier from huggingface.co/models.") - parser.add_argument("--output_path", - type=str, - default="./output_model", - help="The path of your saved model after finetuning.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="google/vit-base-patch16-224", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_path", type=str, default="./output_model", help="The path of your saved model after finetuning." + ) parser.add_argument( "--plugin", type=str, default="gemini", - help= - "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'." + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.", ) parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs.") - parser.add_argument("--batch_size", - type=int, - default=32, - help="Batch size (per dp group) for the training dataloader.") - parser.add_argument("--tp_size", - type=int, - default=1, - help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.") - parser.add_argument("--pp_size", - type=int, - default=1, - help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.") - parser.add_argument("--learning_rate", - type=float, - default=3e-4, - help="Initial learning rate (after the potential warmup period) to use.") - parser.add_argument("--warmup_ratio", - type=float, - default=0.3, - help="Ratio of warmup steps against total training steps.") + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--tp_size", + type=int, + default=1, + help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.", + ) + parser.add_argument( + "--pp_size", + type=int, + default=1, + help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=3e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--warmup_ratio", type=float, default=0.3, help="Ratio of warmup steps against total training steps." + ) parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.") parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") @@ -49,29 +52,30 @@ def parse_demo_args(): def parse_benchmark_args(): - parser = get_default_parser() - parser.add_argument("--model_name_or_path", - type=str, - default="google/vit-base-patch16-224", - help="Path to a pretrained model or model identifier from huggingface.co/models.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="google/vit-base-patch16-224", + help="Path to a pretrained model or model identifier from huggingface.co/models.", + ) parser.add_argument( "--plugin", type=str, default="gemini", - help= - "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'." + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.", + ) + parser.add_argument( + "--batch_size", type=int, default=8, help="Batch size (per dp group) for the training dataloader." ) - parser.add_argument("--batch_size", - type=int, - default=8, - help="Batch size (per dp group) for the training dataloader.") parser.add_argument("--num_labels", type=int, default=10, help="Number of labels for classification.") - parser.add_argument("--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.") + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.") parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.") diff --git a/examples/images/vit/data.py b/examples/images/vit/data.py index 77a8ad525056..5361fe9a3bad 100644 --- a/examples/images/vit/data.py +++ b/examples/images/vit/data.py @@ -4,13 +4,11 @@ class BeansDataset(Dataset): - - def __init__(self, image_processor, tp_size=1, split='train'): - + def __init__(self, image_processor, tp_size=1, split="train"): super().__init__() self.image_processor = image_processor - self.ds = load_dataset('beans')[split] - self.label_names = self.ds.features['labels'].names + self.ds = load_dataset("beans")[split] + self.label_names = self.ds.features["labels"].names while len(self.label_names) % tp_size != 0: # ensure that the number of labels is multiple of tp_size self.label_names.append(f"pad_label_{len(self.label_names)}") @@ -26,13 +24,13 @@ def __getitem__(self, idx): return self.inputs[idx] def process_example(self, example): - input = self.image_processor(example['image'], return_tensors='pt') - input['labels'] = example['labels'] + input = self.image_processor(example["image"], return_tensors="pt") + input["labels"] = example["labels"] return input def beans_collator(batch): return { - 'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0), - 'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64) + "pixel_values": torch.cat([data["pixel_values"] for data in batch], dim=0), + "labels": torch.tensor([data["labels"] for data in batch], dtype=torch.int64), } diff --git a/examples/images/vit/requirements.txt b/examples/images/vit/requirements.txt index edad87ca380f..69e41c61cd67 100644 --- a/examples/images/vit/requirements.txt +++ b/examples/images/vit/requirements.txt @@ -3,4 +3,4 @@ torch >= 1.8.1 numpy>=1.24.1 tqdm>=4.61.2 transformers>=4.20.0 -datasets \ No newline at end of file +datasets diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index d822fe23ecf0..b770bc9cfb95 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -25,18 +25,16 @@ def format_num(num: int, bytes=False): def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224): - pixel_values = torch.randn(batch_size, - num_channels, - height, - width, - device=torch.cuda.current_device(), - dtype=torch.float) + pixel_values = torch.randn( + batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float + ) labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64) return dict(pixel_values=pixel_values, labels=labels) def colo_memory_cap(size_in_GB): from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) @@ -44,7 +42,6 @@ def colo_memory_cap(size_in_GB): def main(): - args = parse_benchmark_args() # Launch ColossalAI @@ -75,22 +72,24 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': - plugin = HybridParallelPlugin(tp_size=2, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_all_optimization=True, - precision='fp16', - initial_scale=1) + elif args.plugin == "hybrid_parallel": + plugin = HybridParallelPlugin( + tp_size=2, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision="fp16", + initial_scale=1, + ) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Set optimizer @@ -119,12 +118,9 @@ def criterion(outputs, inputs): if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: # run pipeline forward backward batch = iter([batch]) - outputs = booster.execute_pipeline(batch, - model, - criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + batch, model, criterion, optimizer, return_loss=True, return_outputs=True + ) else: outputs = model(**batch) loss = criterion(outputs, None) @@ -146,7 +142,8 @@ def criterion(outputs, inputs): f"plugin: {args.plugin}, " f"throughput: {throughput}, " f"maximum memory usage per gpu: {max_mem}.", - ranks=[0]) + ranks=[0], + ) torch.cuda.empty_cache() diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 206d8694b8f5..81009b3707b6 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -25,19 +25,21 @@ def move_to_cuda(batch, device): return {k: v.to(device) for k, v in batch.items()} -def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor], - data_iter: Iterator, booster: Booster): +def run_forward_backward( + model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + data_iter: Iterator, + booster: Booster, +): if optimizer is not None: optimizer.zero_grad() if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: # run pipeline forward backward when enabling pp in hybrid parallel plugin - output_dict = booster.execute_pipeline(data_iter, - model, - criterion, - optimizer, - return_loss=True, - return_outputs=True) - loss, outputs = output_dict['loss'], output_dict['outputs'] + output_dict = booster.execute_pipeline( + data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True + ) + loss, outputs = output_dict["loss"], output_dict["outputs"] else: batch = next(data_iter) batch = move_to_cuda(batch, torch.cuda.current_device()) @@ -49,9 +51,16 @@ def run_forward_backward(model: nn.Module, optimizer: Optimizer, criterion: Call return loss, outputs -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor], - lr_scheduler: LRScheduler, dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): - +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + lr_scheduler: LRScheduler, + dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): torch.cuda.synchronize() num_steps = len(dataloader) @@ -61,12 +70,11 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar tp_rank = dist.get_rank(booster.plugin.tp_group) dp_rank = dist.get_rank(booster.plugin.dp_group) - enable_pbar = tp_rank == 0 and dp_rank == 0 \ - and booster.plugin.stage_manager.is_last_stage() + enable_pbar = tp_rank == 0 and dp_rank == 0 and booster.plugin.stage_manager.is_last_stage() model.train() - with tqdm(range(num_steps), desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar: + with tqdm(range(num_steps), desc=f"Epoch [{epoch + 1}]", disable=not enable_pbar) as pbar: for _ in pbar: loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) optimizer.step() @@ -74,13 +82,18 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C # Print batch loss if enable_pbar: - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) @torch.no_grad() -def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], torch.Tensor], - eval_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): - +def evaluate_model( + epoch: int, + model: nn.Module, + criterion: Callable[[Any, Any], torch.Tensor], + eval_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): torch.cuda.synchronize() model.eval() accum_loss = torch.zeros(1, device=torch.cuda.current_device()) @@ -99,13 +112,13 @@ def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], to_accum = to_accum and booster.plugin.stage_manager.is_last_stage() if to_accum: - accum_loss += (loss / len(eval_dataloader)) + accum_loss += loss / len(eval_dataloader) logits = outputs["logits"] preds = torch.argmax(logits, dim=1) labels = batch["labels"] total_num += batch["labels"].shape[0] - accum_correct += (torch.sum(preds == labels)) + accum_correct += torch.sum(preds == labels) dist.all_reduce(accum_loss) dist.all_reduce(total_num) @@ -113,13 +126,14 @@ def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], avg_loss = "{:.4f}".format(accum_loss.item()) accuracy = "{:.4f}".format(accum_correct.item() / total_num.item()) if coordinator.is_master(): - print(f"Evaluation result for epoch {epoch + 1}: \ + print( + f"Evaluation result for epoch {epoch + 1}: \ average_loss={avg_loss}, \ - accuracy={accuracy}.") + accuracy={accuracy}." + ) def main(): - args = parse_demo_args() # Launch ColossalAI @@ -136,14 +150,14 @@ def main(): transformers.utils.logging.set_verbosity_error() # Reset tp_size and pp_size to 1 if not using hybrid parallel. - if args.plugin != 'hybrid_parallel': + if args.plugin != "hybrid_parallel": args.tp_size = 1 args.pp_size = 1 # Prepare Dataset image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path) - train_dataset = BeansDataset(image_processor, args.tp_size, split='train') - eval_dataset = BeansDataset(image_processor, args.tp_size, split='validation') + train_dataset = BeansDataset(image_processor, args.tp_size, split="train") + eval_dataset = BeansDataset(image_processor, args.tp_size, split="validation") num_labels = train_dataset.num_labels # Load pretrained ViT model @@ -151,9 +165,9 @@ def main(): config.num_labels = num_labels config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)} config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)} - model = ViTForImageClassification.from_pretrained(args.model_name_or_path, - config=config, - ignore_mismatched_sizes=True) + model = ViTForImageClassification.from_pretrained( + args.model_name_or_path, config=config, ignore_mismatched_sizes=True + ) logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) # Enable gradient checkpointing @@ -162,37 +176,35 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': - plugin = HybridParallelPlugin(tp_size=args.tp_size, - pp_size=args.pp_size, - num_microbatches=None, - microbatch_size=1, - enable_all_optimization=True, - precision='fp16', - initial_scale=1) + elif args.plugin == "hybrid_parallel": + plugin = HybridParallelPlugin( + tp_size=args.tp_size, + pp_size=args.pp_size, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision="fp16", + initial_scale=1, + ) else: raise ValueError(f"Plugin with name {args.plugin} is not supported!") logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Prepare dataloader - train_dataloader = plugin.prepare_dataloader(train_dataset, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=beans_collator) - eval_dataloader = plugin.prepare_dataloader(eval_dataset, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=beans_collator) + train_dataloader = plugin.prepare_dataloader( + train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=beans_collator + ) + eval_dataloader = plugin.prepare_dataloader( + eval_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=beans_collator + ) # Set optimizer optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) @@ -204,17 +216,15 @@ def criterion(outputs, inputs): # Set lr scheduler total_steps = len(train_dataloader) * args.num_epoch num_warmup_steps = int(args.warmup_ratio * total_steps) - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=(len(train_dataloader) * args.num_epoch), - warmup_steps=num_warmup_steps) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=(len(train_dataloader) * args.num_epoch), warmup_steps=num_warmup_steps + ) # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(model=model, - optimizer=optimizer, - criterion=criterion, - dataloader=train_dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost( + model=model, optimizer=optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) # Finetuning logger.info(f"Start finetuning", ranks=[0]) diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py index 67ff13bb5f5e..738f43dc0619 100644 --- a/examples/inference/bench_bloom.py +++ b/examples/inference/bench_bloom.py @@ -11,7 +11,7 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def print_perf_stats(latency_set, config, bs, warmup=3): @@ -25,7 +25,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3): avg = sum(latency_set) / count num_layers = getattr(config, "num_layers", config.num_hidden_layers) num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 # float16 + num_bytes = 2 # float16 print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) @@ -53,7 +53,7 @@ def bench_bloom(args): generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), - "attention_mask": torch.ones((max_batch_size, max_input_len)) + "attention_mask": torch.ones((max_batch_size, max_input_len)), } for t in input_tokens: if torch.is_tensor(input_tokens[t]): @@ -77,7 +77,7 @@ def bench_bloom(args): def check_bloom(rank, world_size, port, args): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") bench_bloom(args) @@ -89,11 +89,11 @@ def test_bloom(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') - parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') - parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') - parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") args = parser.parse_args() diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index d2016a4587e6..6e49fa80c812 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -12,7 +12,7 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def init_to_get_rotary(self, base=10000): @@ -28,8 +28,9 @@ def init_to_get_rotary(self, base=10000): else: max_seq_len = 2048 * rope_scaling_factor base = float(base) - inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / - self.config.head_dim_)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_) + ) t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) @@ -75,8 +76,8 @@ def run_llama_test(args): generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { - "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), - "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } iters = 10 @@ -105,7 +106,7 @@ def run_llama_test(args): def check_llama(rank, world_size, port, args): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_test(args) @@ -117,11 +118,11 @@ def test_llama(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') - parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') - parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') - parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") args = parser.parse_args() diff --git a/examples/language/bert/benchmark.py b/examples/language/bert/benchmark.py index ae8b2269a534..10bd367fda5b 100644 --- a/examples/language/bert/benchmark.py +++ b/examples/language/bert/benchmark.py @@ -32,9 +32,7 @@ class RandintDataset(Dataset): - def __init__(self, dataset_length: int, sequence_length: int, vocab_size: int, n_class: int): - self._sequence_length = sequence_length self._vocab_size = vocab_size self._n_class = n_class @@ -42,10 +40,13 @@ def __init__(self, dataset_length: int, sequence_length: int, vocab_size: int, n self._datas = torch.randint( low=0, high=self._vocab_size, - size=(self._dataset_length, self._sequence_length,), + size=( + self._dataset_length, + self._sequence_length, + ), dtype=torch.long, ) - self._labels = torch.randint(low=0, high=self._n_class, size=(self._dataset_length, 1), dtype=torch.long) + self._labels = torch.randint(low=0, high=self._n_class, size=(self._dataset_length, 1), dtype=torch.long) def __len__(self): return self._dataset_length @@ -59,13 +60,15 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], - help="plugin to use") + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) parser.add_argument( "--model_type", type=str, @@ -88,13 +91,13 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) @@ -103,10 +106,9 @@ def main(): # Prepare Dataloader # ============================== - train_dataset = RandintDataset(dataset_length=DATASET_LEN, - sequence_length=SEQ_LEN, - vocab_size=VOCAB_SIZE, - n_class=NUM_LABELS) + train_dataset = RandintDataset( + dataset_length=DATASET_LEN, sequence_length=SEQ_LEN, vocab_size=VOCAB_SIZE, n_class=NUM_LABELS + ) train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE) # ==================================== @@ -159,16 +161,12 @@ def main(): # Benchmark model # ============================== - results = benchmark(model, - booster, - optimizer, - lr_scheduler, - train_dataloader, - criterion=criterion, - epoch_num=NUM_EPOCHS) + results = benchmark( + model, booster, optimizer, lr_scheduler, train_dataloader, criterion=criterion, epoch_num=NUM_EPOCHS + ) coordinator.print_on_master(results) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/bert/benchmark_utils.py b/examples/language/bert/benchmark_utils.py index 886017a41826..04d55cb2e7b6 100644 --- a/examples/language/bert/benchmark_utils.py +++ b/examples/language/bert/benchmark_utils.py @@ -112,8 +112,9 @@ def benchmark( start_time = time() for epoch in range(epoch_num): - with tqdm(dataloader, desc=f'Epoch [{epoch + 1}/{epoch_num}]', - disable=not DistCoordinator().is_master()) as pbar: + with tqdm( + dataloader, desc=f"Epoch [{epoch + 1}/{epoch_num}]", disable=not DistCoordinator().is_master() + ) as pbar: for data in pbar: inputs, labels = data[0].cuda(), data[1].cuda() outputs = model(inputs, labels=labels) @@ -137,7 +138,9 @@ def benchmark( } logger.info(fmt({f"Memory results (batch_size={batch_size})": memory[f"batch_size_{batch_size}"]})) - throughput[f"batch_size_{batch_size}"] = {"throughput:": "{:.1f}".format(all_sample * DistCoordinator().world_size / (end_time - start_time))} + throughput[f"batch_size_{batch_size}"] = { + "throughput:": "{:.1f}".format(all_sample * DistCoordinator().world_size / (end_time - start_time)) + } logger.info(fmt({f"Throughput results (batch_size={batch_size})": throughput[f"batch_size_{batch_size}"]})) results["throughput"] = throughput diff --git a/examples/language/bert/data.py b/examples/language/bert/data.py index 981cedcca8c2..ef51f938dc4f 100644 --- a/examples/language/bert/data.py +++ b/examples/language/bert/data.py @@ -5,7 +5,6 @@ class GLUEDataBuilder: - task_text_field_map = { "cola": ["sentence"], "sst2": ["sentence"], @@ -84,10 +83,9 @@ def prepare_data(self): AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) def train_dataloader(self): - return self.plugin.prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) def val_dataloader(self): if len(self.eval_splits) == 1: @@ -108,7 +106,6 @@ def test_dataloader(self): ] def convert_to_features(self, example_batch): - # Either encode single sentence or sentence pairs if len(self.text_fields) > 1: texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) @@ -116,10 +113,9 @@ def convert_to_features(self, example_batch): texts_or_text_pairs = example_batch[self.text_fields[0]] # Tokenize the text/text pairs - features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, - max_length=self.max_seq_length, - padding='max_length', - truncation=True) + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) # Rename label to labels to make it easier to pass to model forward features["labels"] = example_batch["label"] diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index fb6e4332c2f9..563cfa58d5f6 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -1,5 +1,4 @@ import argparse -from contextlib import nullcontext from typing import Callable, List, Union import evaluate @@ -7,7 +6,7 @@ import torch.distributed as dist import torch.nn as nn from data import GLUEDataBuilder -from torch.optim import Adam, Optimizer +from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from tqdm import tqdm @@ -22,7 +21,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -109,7 +107,7 @@ def evaluate_subset(dataloader: DataLoader): results = metric.compute() dist.all_reduce(accum_loss.div_(len(dataloader))) if coordinator.is_master() and results is not None: - results['loss'] = accum_loss.item() / coordinator.world_size + results["loss"] = accum_loss.item() / coordinator.world_size return results @@ -120,13 +118,20 @@ def evaluate_subset(dataloader: DataLoader): final_results = {} for split, sub_loader in zip(eval_splits, test_dataloader): results = evaluate_subset(sub_loader) - final_results.update({f'{k}_{split}': v for k, v in results.items()}) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) return final_results -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, - train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): - +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + _criterion: Callable, + lr_scheduler: LRScheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage) @@ -135,20 +140,17 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: model.train() optimizer.zero_grad() train_dataloader_iter = iter(train_dataloader) - with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not print_flag) as pbar: + with tqdm(range(total_step), desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not print_flag) as pbar: # Forward pass for _ in pbar: if use_pipeline: - outputs = booster.execute_pipeline(train_dataloader_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) # Backward and optimize if is_pp_last_stage: - loss = outputs['loss'] - pbar.set_postfix({'loss': loss.item()}) + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) else: data = next(train_dataloader_iter) data = move_to_cuda(data) @@ -156,7 +158,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: loss = _criterion(outputs, None) # Backward booster.backward(loss, optimizer) - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) optimizer.step() optimizer.zero_grad() @@ -168,26 +170,28 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'], - help="plugin to use") + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"], + help="plugin to use", + ) parser.add_argument( "--model_type", type=str, default="bert", help="bert or albert", ) - parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached") - parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") + parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") args = parser.parse_args() - if args.model_type == 'bert': + if args.model_type == "bert": model_name = "bert-base-uncased" - elif args.model_type == 'albert': + elif args.model_type == "albert": model_name = "albert-xxlarge-v2" else: raise RuntimeError @@ -204,36 +208,35 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': - + elif args.plugin == "hybrid_parallel": # modify the param accordingly for finetuning test cases - plugin = HybridParallelPlugin(tp_size=1, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_all_optimization=True, - zero_stage=1, - precision='fp16', - initial_scale=1) + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, + ) booster = Booster(plugin=plugin, **booster_kwargs) # ============================== # Prepare Dataloader # ============================== - data_builder = GLUEDataBuilder(model_name, - plugin, - args.task, - train_batch_size=BATCH_SIZE, - eval_batch_size=BATCH_SIZE) + data_builder = GLUEDataBuilder( + model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE + ) train_dataloader = data_builder.train_dataloader() test_dataloader = data_builder.test_dataloader() @@ -283,10 +286,9 @@ def _criterion(outputs, inputs): # ============================== # Boost with ColossalAI # ============================== - model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, - optimizer, - criterion=_criterion, - lr_scheduler=lr_scheduler) + model, optimizer, _criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler + ) # ============================== # Train model @@ -294,14 +296,22 @@ def _criterion(outputs, inputs): for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) - results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task, - data_builder.eval_splits, booster, coordinator) + results = evaluate_model( + model, + _criterion, + test_dataloader, + data_builder.num_labels, + args.task, + data_builder.eval_splits, + booster, + coordinator, + ) if coordinator.is_master(): print(results) - if args.target_f1 is not None and 'f1' in results: - assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/gpt/experiments/auto_offload/model_zoo.py b/examples/language/gpt/experiments/auto_offload/model_zoo.py index 35e44608f810..75968a0b1da9 100644 --- a/examples/language/gpt/experiments/auto_offload/model_zoo.py +++ b/examples/language/gpt/experiments/auto_offload/model_zoo.py @@ -2,22 +2,20 @@ import torch.nn as nn from transformers import GPT2Config, GPT2LMHeadModel -class GPTLMModel(nn.Module): - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257): +class GPTLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257): super().__init__() self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) def forward(self, input_ids, attention_mask): # Only return lm_logits @@ -25,7 +23,6 @@ def forward(self, input_ids, attention_mask): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -36,6 +33,7 @@ def forward(self, logits, labels): # Flatten the tokens return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + def get_gpt2_components(model_type: str, batch_size: int): vocab_size = 1024 seq_len = 8 @@ -62,4 +60,4 @@ def gpt2_data_gen(device="cuda"): kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) return kwargs - return gpt2_model_builder, gpt2_data_gen \ No newline at end of file + return gpt2_model_builder, gpt2_data_gen diff --git a/examples/language/gpt/experiments/auto_offload/requirements.txt b/examples/language/gpt/experiments/auto_offload/requirements.txt index 3ebde8d460aa..137a69e80498 100644 --- a/examples/language/gpt/experiments/auto_offload/requirements.txt +++ b/examples/language/gpt/experiments/auto_offload/requirements.txt @@ -1,2 +1,2 @@ colossalai >= 0.1.12 -torch >= 1.8.1 \ No newline at end of file +torch >= 1.8.1 diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index 89415c23f93c..521527da51e0 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -3,7 +3,6 @@ import pytest import torch -from model_zoo import GPTLMLoss, get_gpt2_components from torch.utils._pytree import tree_map import colossalai @@ -14,18 +13,19 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.testing import spawn from colossalai.utils import get_current_device +from model_zoo import GPTLMLoss, get_gpt2_components def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--model_type', type=str, default="gpt2_medium") - parser.add_argument('--batch_size', type=int, default=64) - parser.add_argument('--solver_type', type=str, default='asyn') - parser.add_argument('--memory_budget', type=float, default=16) + parser.add_argument("--model_type", type=str, default="gpt2_medium") + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--solver_type", type=str, default="asyn") + parser.add_argument("--memory_budget", type=float, default=16) return parser.parse_args() -@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed") def train_gpt(args): memory_budget = args.memory_budget * 1024 * 1024 * 1024 solver_type = args.solver_type @@ -34,10 +34,15 @@ def train_gpt(args): # build model model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size) - label = torch.randint(low=0, high=128, size=( - 64, - 8, - ), device=get_current_device()) + label = torch.randint( + low=0, + high=128, + size=( + 64, + 8, + ), + device=get_current_device(), + ) criterion = GPTLMLoss() start_time = time.time() @@ -80,18 +85,20 @@ def train_gpt(args): exec_time = sum(sorted(time_list)[:5]) / 5 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 - print(f'solver_type: {solver_type} | model_type: {model_type}') - print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(f"solver_type: {solver_type} | model_type: {model_type}") + print( + f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB " + f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|" + ) print(time_list) def run(rank, world_size, port, args): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") train_gpt(args) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() spawn(run, 1, args=args) diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py index 84b02633e775..f3d35dd9042b 100644 --- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py +++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py @@ -29,8 +29,8 @@ def get_gpu_mem(): return torch.cuda.memory_allocated() / 1024**2 -def get_mem_info(prefix=''): - return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' +def get_mem_info(prefix=""): + return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" def get_tflops(model_numel, batch_size, seq_len, step_time): @@ -51,14 +51,14 @@ def main(): logger = get_dist_logger() config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM) if FP16: - model = GPT2LMHeadModel(config=config).half().to('cuda') + model = GPT2LMHeadModel(config=config).half().to("cuda") else: - model = GPT2LMHeadModel(config=config).to('cuda') + model = GPT2LMHeadModel(config=config).to("cuda") global_numel = sum([p.numel() for p in model.parameters()]) meta_input_sample = { - 'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), - 'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'), + "input_ids": torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to("meta"), + "attention_mask": torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to("meta"), } gm, solution = autoparallelize(model, meta_input_sample, return_solution=True) @@ -72,7 +72,7 @@ def main(): criterion = GPTLMLoss() optimizer = torch.optim.Adam(gm.parameters(), lr=0.01) - logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + logger.info(get_mem_info(prefix="After init model, "), ranks=[0]) get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH) torch.cuda.synchronize() model.train() @@ -89,10 +89,11 @@ def main(): torch.cuda.synchronize() step_time = time() - start logger.info( - f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}', - ranks=[0]) + f"[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}", + ranks=[0], + ) torch.cuda.synchronize() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/gpt/experiments/auto_parallel/gpt_modules.py b/examples/language/gpt/experiments/auto_parallel/gpt_modules.py index 95feaec38c26..ad9a19777284 100644 --- a/examples/language/gpt/experiments/auto_parallel/gpt_modules.py +++ b/examples/language/gpt/experiments/auto_parallel/gpt_modules.py @@ -8,7 +8,6 @@ class GPT2MLP(nn.Module): - def __init__(self, intermediate_size, config): super().__init__() embed_dim = config.hidden_size @@ -30,15 +29,15 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl # 2. The order of split and view op has been changed in the customized GPT2Attention class, the new # order is same as megatron-lm gpt model. class GPT2Attention(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), - dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), ) self.register_buffer("masked_bias", torch.tensor(-1e4)) @@ -64,7 +63,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (value.size(-1)**0.5) + attn_weights = attn_weights / (value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: @@ -72,7 +71,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) if attention_mask is not None: @@ -93,7 +92,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): def _split_heads(self, tensor, num_heads, attn_head_size): new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): tensor = tensor.permute(0, 2, 1, 3).contiguous() @@ -106,10 +105,9 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - qkv = self.c_attn(hidden_states) query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) - present = (key, value) + (key, value) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(attn_output) @@ -117,7 +115,6 @@ def forward( class GPT2Block(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -152,7 +149,6 @@ def forward( class GPT2Model(GPT2PreTrainedModel): - def __init__(self, config): super().__init__(config) @@ -189,11 +185,9 @@ def forward( # GPT2Attention mask. attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 - encoder_attention_mask = None - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -217,7 +211,6 @@ def forward( class GPT2LMHeadModel(GPT2PreTrainedModel): - def __init__(self, config): super().__init__(config) self.transformer = GPT2Model(config) @@ -241,7 +234,6 @@ def forward( class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() diff --git a/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py b/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py index c31b3fa6d103..47cc87980556 100644 --- a/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py +++ b/examples/language/gpt/experiments/pipeline_parallel/model_zoo.py @@ -4,22 +4,25 @@ ## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint - self.config = GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size) + self.config = GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) self.model = GPT2LMHeadModel(self.config) if checkpoint: self.model.gradient_checkpointing_enable() @@ -70,4 +73,4 @@ def model_builder(model_size: str) -> callable: raise TypeError(f"model_builder {model_size}") -__all__ = ['model_builder'] +__all__ = ["model_builder"] diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py index 749243e57836..17692e90a03c 100644 --- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py +++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py @@ -3,41 +3,34 @@ from functools import partial import torch -from model_zoo import model_builder from torch import nn -from tqdm import tqdm from colossalai.fx import ColoTracer -from colossalai.fx.passes.adding_split_node_pass import ( - avgnode_split_pass, - gpipe_dp_split_pass, - split_with_split_nodes_pass, -) +from colossalai.fx.passes.adding_split_node_pass import gpipe_dp_split_pass, split_with_split_nodes_pass from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology -from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine from colossalai.legacy.pipeline.rpc.utils import rpc_run from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer import HybridAdam +from model_zoo import model_builder def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--model_type', type=str, default="gpt2_medium") - parser.add_argument('--world_size', type=int, default=2) - parser.add_argument('--batch_size', type=int, default=16) - parser.add_argument('--dp_degree', type=int, default=1) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--num_microbatches', type=int, default=2) - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29011') - parser.add_argument('--num_worker_threads', type=int, default=128) + parser.add_argument("--model_type", type=str, default="gpt2_medium") + parser.add_argument("--world_size", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--num_microbatches", type=int, default=2) + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29011") + parser.add_argument("--num_worker_threads", type=int, default=128) return parser.parse_args() class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -63,16 +56,16 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): # Create annotated model which is noted where to be splitted. def get_annotated_model(model, data_kwargs, num_stages, num_microbatches): tracer = ColoTracer() - meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + meta_args = {k: v.to("meta") for k, v in data_kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) - interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()]) + interp_meta_args = tuple([v.to("meta") for k, v in data_kwargs.items()]) interp = MetaInfoProp(gm) interp.run(*interp_meta_args) - #annotated_model = avgnode_split_pass(gm, num_stages) - annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01) + # annotated_model = avgnode_split_pass(gm, num_stages) + annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode="block", block_limit=0.01) return annotated_model @@ -83,7 +76,7 @@ def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, n topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_topo', topo) + setattr(submodule, "_topo", topo) return split_submodules[pp_rank + 1] @@ -107,8 +100,10 @@ def run_master(args): disable_existing_loggers() logger = get_dist_logger() - logger.info(f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}", - ranks=[0]) + logger.info( + f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}", + ranks=[0], + ) torch.manual_seed(123) @@ -117,26 +112,28 @@ def run_master(args): # warm up pipeline fx partition input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) - warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask} + warmup_data_kwargs = {"input_ids": input_ids, "attention_mask": attn_mask} # create model - logger.info(f'start model_builder') + logger.info(f"start model_builder") model = model_builder(model_type)(checkpoint=False) - logger.info(f'end model_builder') + logger.info(f"end model_builder") # set 1f1b pipeline engine - pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches), - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=1, - criterion=criterion, - metric=None, - checkpoint=False) + pp_engine = FillDrainPipelineEngine( + partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches), + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=1, + criterion=criterion, + metric=None, + checkpoint=False, + ) partition_numels = pp_engine.remote_numels() for rank, numel in partition_numels.items(): - logger.info(f'{rank=} numel in the partition:{numel}') + logger.info(f"{rank=} numel in the partition:{numel}") # build optim pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) @@ -145,7 +142,7 @@ def run_master(args): for n in range(NUM_STEPS): # we just use randomly generated data here input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) - batch = {'input_ids': input_ids, 'attention_mask': attn_mask} + batch = {"input_ids": input_ids, "attention_mask": attn_mask} start = time.time() outputs = pp_engine.forward_backward(batch=batch, labels=input_ids, forward_only=False) @@ -175,6 +172,6 @@ def run_master(args): logger.info(f"Avg TFLOPS per GPU is {sum(gpu_tflops) / world_size:.3f}") -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() rpc_run(args, run_master) diff --git a/examples/language/gpt/gemini/commons/model_zoo.py b/examples/language/gpt/gemini/commons/model_zoo.py index 65124d9e4884..0f4517549db2 100644 --- a/examples/language/gpt/gemini/commons/model_zoo.py +++ b/examples/language/gpt/gemini/commons/model_zoo.py @@ -4,22 +4,25 @@ ## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint - self.config = GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size) + self.config = GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) self.model = GPT2LMHeadModel(self.config) if checkpoint: self.model.gradient_checkpointing_enable() @@ -82,4 +85,4 @@ def model_builder(model_size: str) -> callable: raise TypeError(f"model_builder {model_size}") -__all__ = ['model_builder'] +__all__ = ["model_builder"] diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py index 7bd098c1927c..7ed5fdb92b35 100644 --- a/examples/language/gpt/gemini/commons/utils.py +++ b/examples/language/gpt/gemini/commons/utils.py @@ -6,7 +6,6 @@ class DummyProfiler: - def __init__(self): self.step_number = 0 @@ -27,11 +26,13 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): if enable_flag: - return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), - on_trace_ready=tensorboard_trace_handler(save_dir), - record_shapes=True, - profile_memory=True) + return profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), + on_trace_ready=tensorboard_trace_handler(save_dir), + record_shapes=True, + profile_memory=True, + ) else: return nullcontext(DummyProfiler()) diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index f9d30fd15c7b..88b76c654b1d 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -27,7 +27,7 @@ def parse_args(): parser.add_argument( "--distplan", type=str, - default='CAI_Gemini', + default="CAI_Gemini", help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", ) parser.add_argument( @@ -54,7 +54,6 @@ def parse_args(): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -74,8 +73,8 @@ def get_gpu_mem(): return torch.cuda.memory_allocated() / 1024**2 -def get_mem_info(prefix=''): - return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' +def get_mem_info(prefix=""): + return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB" def get_model_size(model: nn.Module): @@ -91,11 +90,11 @@ def model_size_formatter(numel: int) -> str: MB_SIZE = 10**6 KB_SIZE = 10**3 if numel >= GB_SIZE: - return f'{numel / GB_SIZE:.1f}B' + return f"{numel / GB_SIZE:.1f}B" elif numel >= MB_SIZE: - return f'{numel / MB_SIZE:.1f}M' + return f"{numel / MB_SIZE:.1f}M" elif numel >= KB_SIZE: - return f'{numel / KB_SIZE:.1f}K' + return f"{numel / KB_SIZE:.1f}K" else: return str(numel) @@ -103,7 +102,7 @@ def model_size_formatter(numel: int) -> str: def set_cpu_maximum_parallelism(): conf_str = torch.__config__.parallel_info() inter_str = conf_str.split("hardware_concurrency() : ")[1] - max_concurrency = inter_str.split('\n')[0] + max_concurrency = inter_str.split("\n")[0] os.environ["OMP_NUM_THREADS"] = max_concurrency print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.") @@ -130,7 +129,7 @@ def main(): WARMUP_STEPS = 1 assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median" - PROF_FLAG = False # The flag of profiling, False by default + PROF_FLAG = False # The flag of profiling, False by default disable_existing_loggers() colossalai.launch_from_torch(config={}) @@ -159,10 +158,9 @@ def main(): plugin = None if args.distplan.startswith("CAI_ZeRO"): - plugin = LowLevelZeroPlugin(stage=zero_stage, - reduce_bucket_size_in_m=12, - overlap_communication=True, - verbose=True) + plugin = LowLevelZeroPlugin( + stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True + ) elif args.distplan == "CAI_Gemini": plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd) else: @@ -171,7 +169,7 @@ def main(): # build a highly optimized gpu/cpu optimizer optimizer = HybridAdam(model.parameters(), lr=1e-3) - logger.info(get_mem_info(prefix='After init optim, '), ranks=[0]) + logger.info(get_mem_info(prefix="After init optim, "), ranks=[0]) elif args.distplan.startswith("Pytorch"): assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples." model = model_builder(args.model_type)(checkpoint=True).cuda() @@ -180,6 +178,7 @@ def main(): optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) elif args.distplan.endswith("ZeRO"): from torch.distributed.optim import ZeroRedundancyOptimizer + optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3) else: @@ -191,7 +190,7 @@ def main(): # model is shared after TP numel = get_model_size(model) logger.info(f"the size of testing model size is {model_size_formatter(numel)}.") - logger.info(get_mem_info(prefix='After init model, '), ranks=[0]) + logger.info(get_mem_info(prefix="After init model, "), ranks=[0]) # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree) @@ -213,19 +212,19 @@ def train_step(): torch.cuda.synchronize() fwd_end = time() fwd_time = fwd_end - start - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0]) + logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Forward "), ranks=[0]) booster.backward(loss, optimizer) torch.cuda.synchronize() bwd_end = time() bwd_time = bwd_end - fwd_end - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0]) + logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Backward "), ranks=[0]) optimizer.step() torch.cuda.synchronize() optim_time = time() - bwd_end step_time = time() - start - logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0]) + logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Optimizer step "), ranks=[0]) step_tflops = get_tflops_func(step_time) logger.info( @@ -235,10 +234,9 @@ def train_step(): if n >= WARMUP_STEPS: tflops_list.append(step_tflops) - demo_profiler = get_profile_context(PROF_FLAG, - WARMUP_STEPS, - NUM_STEPS - WARMUP_STEPS, - save_dir=f"profile/{get_time_stamp()}-demo") + demo_profiler = get_profile_context( + PROF_FLAG, WARMUP_STEPS, NUM_STEPS - WARMUP_STEPS, save_dir=f"profile/{get_time_stamp()}-demo" + ) with demo_profiler as prof: for n in range(NUM_STEPS): @@ -251,5 +249,5 @@ def train_step(): torch.cuda.synchronize() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/gpt/hybridparallelism/data.py b/examples/language/gpt/hybridparallelism/data.py index 981cedcca8c2..ef51f938dc4f 100644 --- a/examples/language/gpt/hybridparallelism/data.py +++ b/examples/language/gpt/hybridparallelism/data.py @@ -5,7 +5,6 @@ class GLUEDataBuilder: - task_text_field_map = { "cola": ["sentence"], "sst2": ["sentence"], @@ -84,10 +83,9 @@ def prepare_data(self): AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) def train_dataloader(self): - return self.plugin.prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) def val_dataloader(self): if len(self.eval_splits) == 1: @@ -108,7 +106,6 @@ def test_dataloader(self): ] def convert_to_features(self, example_batch): - # Either encode single sentence or sentence pairs if len(self.text_fields) > 1: texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) @@ -116,10 +113,9 @@ def convert_to_features(self, example_batch): texts_or_text_pairs = example_batch[self.text_fields[0]] # Tokenize the text/text pairs - features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, - max_length=self.max_seq_length, - padding='max_length', - truncation=True) + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) # Rename label to labels to make it easier to pass to model forward features["labels"] = example_batch["label"] diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 03e5ec91b3fe..62804eff8ea5 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -1,5 +1,4 @@ import argparse -from contextlib import nullcontext from typing import Callable, List, Union import evaluate @@ -7,7 +6,7 @@ import torch.distributed as dist import torch.nn as nn from data import GLUEDataBuilder -from torch.optim import Adam, Optimizer +from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from tqdm import tqdm @@ -17,7 +16,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -104,7 +102,7 @@ def evaluate_subset(dataloader: DataLoader): results = metric.compute() dist.all_reduce(accum_loss.div_(len(dataloader))) if coordinator.is_master() and results is not None: - results['loss'] = accum_loss.item() / coordinator.world_size + results["loss"] = accum_loss.item() / coordinator.world_size return results @@ -115,13 +113,20 @@ def evaluate_subset(dataloader: DataLoader): final_results = {} for split, sub_loader in zip(eval_splits, test_dataloader): results = evaluate_subset(sub_loader) - final_results.update({f'{k}_{split}': v for k, v in results.items()}) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) return final_results -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler, - train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): - +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + _criterion: Callable, + lr_scheduler: LRScheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() total_step = len(train_dataloader) @@ -129,22 +134,21 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: model.train() optimizer.zero_grad() train_dataloader_iter = iter(train_dataloader) - with tqdm(range(total_step), - desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', - disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: + with tqdm( + range(total_step), + desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", + disable=not (coordinator.is_master() or is_pp_last_stage), + ) as pbar: # Forward pass for _ in pbar: if use_pipeline: - outputs = booster.execute_pipeline(train_dataloader_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) # Backward and optimize if is_pp_last_stage: - loss = outputs['loss'] - pbar.set_postfix({'loss': loss.item()}) + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) else: data = next(train_dataloader_iter) data = move_to_cuda(data) @@ -152,7 +156,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: loss = _criterion(outputs, None) # Backward booster.backward(loss, optimizer) - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) optimizer.step() optimizer.zero_grad() @@ -164,24 +168,26 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'], - help="plugin to use") + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"], + help="plugin to use", + ) parser.add_argument( "--model_type", type=str, default="gpt2", help="only gpt2 now", ) - parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached") - parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") + parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") args = parser.parse_args() - if args.model_type == 'gpt2': + if args.model_type == "gpt2": model_name = "gpt2" else: raise RuntimeError @@ -198,36 +204,35 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': - + elif args.plugin == "hybrid_parallel": # modify the param accordingly for finetuning test cases - plugin = HybridParallelPlugin(tp_size=1, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_all_optimization=True, - zero_stage=1, - precision='fp16', - initial_scale=1) + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + zero_stage=1, + precision="fp16", + initial_scale=1, + ) booster = Booster(plugin=plugin, **booster_kwargs) # ============================== # Prepare Dataloader # ============================== - data_builder = GLUEDataBuilder(model_name, - plugin, - args.task, - train_batch_size=BATCH_SIZE, - eval_batch_size=BATCH_SIZE) + data_builder = GLUEDataBuilder( + model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE + ) train_dataloader = data_builder.train_dataloader() test_dataloader = data_builder.test_dataloader() @@ -275,10 +280,9 @@ def _criterion(outputs, inputs): # ============================== # Boost with ColossalAI # ============================== - model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, - optimizer, - criterion=_criterion, - lr_scheduler=lr_scheduler) + model, optimizer, _criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler + ) # ============================== # Train model @@ -286,14 +290,22 @@ def _criterion(outputs, inputs): for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) - results = evaluate_model(model, _criterion, test_dataloader, data_builder.num_labels, args.task, - data_builder.eval_splits, booster, coordinator) + results = evaluate_model( + model, + _criterion, + test_dataloader, + data_builder.num_labels, + args.task, + data_builder.eval_splits, + booster, + coordinator, + ) if coordinator.is_master(): print(results) - if args.target_f1 is not None and 'f1' in results: - assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py index 7bf53303948a..bc3dcb85cf1a 100644 --- a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py +++ b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py @@ -11,8 +11,10 @@ TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE) # if you do no want zero, just comment out this dictionary -zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()), - optimizer_config=dict(initial_scale=2**5)) +zero = dict( + model_config=dict(tensor_placement_policy="cuda", shard_strategy=TensorShardStrategy()), + optimizer_config=dict(initial_scale=2**5), +) optimizer = dict( type=HybridAdam, @@ -27,5 +29,5 @@ # for the current model implementation, mode can only be 1D or None parallel = dict( pipeline=1, - tensor=dict(size=2, mode='1d'), + tensor=dict(size=2, mode="1d"), ) diff --git a/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py index 9f9816b3004f..7413764dad81 100644 --- a/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py +++ b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py @@ -11,8 +11,10 @@ TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE) # if you do no want zero, just comment out this dictionary -zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()), - optimizer_config=dict(initial_scale=2**16)) +zero = dict( + model_config=dict(tensor_placement_policy="cuda", shard_strategy=TensorShardStrategy()), + optimizer_config=dict(initial_scale=2**16), +) optimizer = dict( type=HybridAdam, @@ -27,5 +29,5 @@ # for the current model implementation, mode can only be 1D or None parallel = dict( pipeline=1, - tensor=dict(size=2, mode='1d'), # for the current model implementation, mode can only be 1D or None + tensor=dict(size=2, mode="1d"), # for the current model implementation, mode can only be 1D or None ) diff --git a/examples/language/gpt/titans/dataset/webtext.py b/examples/language/gpt/titans/dataset/webtext.py index fdfc57e9ba22..e61f73fd9eba 100644 --- a/examples/language/gpt/titans/dataset/webtext.py +++ b/examples/language/gpt/titans/dataset/webtext.py @@ -11,12 +11,11 @@ @DATASETS.register_module class WebtextDataset(Dataset): - def __init__(self, path: Optional[str] = None, seq_len=1024) -> None: super().__init__() if path is not None: root = os.path.dirname(path) - encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt') + encoded_data_cache_path = os.path.join(root, f"gpt_webtext_{seq_len}.pt") if os.path.isfile(encoded_data_cache_path): seq_len_, data, attention_mask = torch.load(encoded_data_cache_path) if seq_len_ == seq_len: @@ -26,12 +25,12 @@ def __init__(self, path: Optional[str] = None, seq_len=1024) -> None: raw_data = [] with open(path) as f: for line in f.readlines(): - raw_data.append(json.loads(line)['text']) - tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + raw_data.append(json.loads(line)["text"]) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.unk_token - encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt') - self.data = encoded_data['input_ids'] - self.attention_mask = encoded_data['attention_mask'] + encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors="pt") + self.data = encoded_data["input_ids"] + self.attention_mask = encoded_data["attention_mask"] else: self.data = torch.randint(0, 50257, (10240, seq_len)) self.attention_mask = torch.ones_like(self.data) @@ -40,4 +39,4 @@ def __len__(self): return len(self.data) def __getitem__(self, index): - return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index] + return {"input_ids": self.data[index], "attention_mask": self.attention_mask[index]}, self.data[index] diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py index a6c80394c50f..b2e3f71a5387 100644 --- a/examples/language/gpt/titans/model/embed.py +++ b/examples/language/gpt/titans/model/embed.py @@ -1,7 +1,6 @@ import torch import torch.nn.init as init from torch import Tensor -from torch import distributed as dist from torch import nn as nn from torch.nn import functional as F from torch.nn.parameter import Parameter @@ -12,7 +11,7 @@ from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row from colossalai.legacy.nn.layer.utils import divide -from colossalai.legacy.registry import LAYERS, LOSSES, MODELS +from colossalai.legacy.registry import LAYERS, LOSSES from colossalai.utils import get_current_device @@ -30,13 +29,9 @@ class VocabParallelEmbedding(torch.nn.Module): will ignore this embedding """ - def __init__(self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - num_tokentypes=0, - dtype=torch.float): + def __init__( + self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes=0, dtype=torch.float + ): super(VocabParallelEmbedding, self).__init__() self.hidden_size = hidden_size @@ -44,11 +39,11 @@ def __init__(self, # Word embeddings (parallel). self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype) - self._word_embeddings_key = 'word_embeddings' + self._word_embeddings_key = "word_embeddings" # Position embedding (serial). self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype) - self._position_embeddings_key = 'position_embeddings' + self._position_embeddings_key = "position_embeddings" # Initialize the position embeddings. # self.init_method(self.position_embeddings.weight) @@ -56,7 +51,7 @@ def __init__(self, # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. - self._tokentype_embeddings_key = 'tokentype_embeddings' + self._tokentype_embeddings_key = "tokentype_embeddings" if self.num_tokentypes > 0: self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype) # Initialize the token-type embeddings. @@ -83,9 +78,9 @@ def add_tokentype_embeddings(self, num_tokentypes): This allows us to load the model normally and then add this embedding. """ if self.tokentype_embeddings is not None: - raise Exception('tokentype embeddings is already initialized') + raise Exception("tokentype embeddings is already initialized") if torch.distributed.get_rank() == 0: - print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True) + print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True) self.num_tokentypes = num_tokentypes self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. @@ -112,19 +107,16 @@ def forward(self, input_ids, position_ids=None, tokentype_ids=None): embeddings = self.embedding_dropout(embeddings) return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): """For easy load.""" state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars) if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] \ - = self.tokentype_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict( + destination, prefix, keep_vars + ) return state_dict_ @@ -138,9 +130,8 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] + if "word_embeddings" in key: + state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. @@ -150,9 +141,8 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] + if "position_embeddings" in key: + state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. @@ -163,15 +153,14 @@ def load_state_dict(self, state_dict, strict=True): else: # for backward compatibility. for key in state_dict.keys(): - if 'tokentype_embeddings' in key: - state_dict_[key.split('tokentype_embeddings.')[1]] \ - = state_dict[key] + if "tokentype_embeddings" in key: + state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key] if len(state_dict_.keys()) > 0: self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) else: - print('***WARNING*** expected tokentype embeddings in the ' - 'checkpoint but could not find it', - flush=True) + print( + "***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True + ) class VocabParallelEmbedding1D(torch.nn.Module): @@ -193,37 +182,41 @@ def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None): # Set the details for compatibility. self.padding_idx = None self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None self.tensor_model_parallel_size = gpc.tensor_parallel_size # Divide the weight matrix along the vocabulary dimension. - self.vocab_start_index, self.vocab_end_index = \ - VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D), - self.tensor_model_parallel_size) - self.num_embeddings_per_partition = self.vocab_end_index - \ - self.vocab_start_index + self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D), self.tensor_model_parallel_size + ) + self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # Allocate weights and initialize. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs)) init.uniform_(self.weight, -1, 1) def forward(self, input_): if self.tensor_model_parallel_size > 1: # Build the mask. - input_mask = (input_ < self.vocab_start_index) | \ - (input_ >= self.vocab_end_index) + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) # Mask the input. masked_input = input_.clone() - self.vocab_start_index masked_input[input_mask] = 0 else: masked_input = input_ # Get the embeddings. - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type, - self.scale_grad_by_freq, self.sparse) + output_parallel = F.embedding( + masked_input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) # Mask the output embedding. if self.tensor_model_parallel_size > 1: output_parallel[input_mask, :] = 0.0 @@ -234,7 +227,6 @@ def forward(self, input_): @LOSSES.register_module class vocab_parallel_cross_entropy(nn.Module): - def __init__(self): super().__init__() @@ -242,20 +234,19 @@ def forward(self, vocab_parallel_logits, target): """Helper function for the cross entropy.""" vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous() target = target[..., 1:].contiguous() - return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)), - target.view(-1)) + return _VocabParallelCrossEntropy.apply( + vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)), target.view(-1) + ) class _VocabParallelCrossEntropy(torch.autograd.Function): - @staticmethod def forward(ctx, vocab_parallel_logits, target): - # Maximum value along vocab dimension across all GPUs. logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, - op=torch.distributed.ReduceOp.MAX, - group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_1D) + ) # Subtract the maximum value. vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) @@ -282,17 +273,17 @@ def forward(ctx, vocab_parallel_logits, target): predicted_logits = predicted_logits_1d.view_as(target) predicted_logits[target_mask] = 0.0 # All reduce is needed to get the chunks from other GPUs. - torch.distributed.all_reduce(predicted_logits, - op=torch.distributed.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.all_reduce( + predicted_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D) + ) # Sum of exponential of logits along vocab dimension across all GPUs. exp_logits = vocab_parallel_logits torch.exp(vocab_parallel_logits, out=exp_logits) sum_exp_logits = exp_logits.sum(dim=-1) - torch.distributed.all_reduce(sum_exp_logits, - op=torch.distributed.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PARALLEL_1D)) + torch.distributed.all_reduce( + sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D) + ) # Loss = log(sum(exp(logits))) - predicted-logit. loss = torch.log(sum_exp_logits) - predicted_logits @@ -304,7 +295,6 @@ def forward(ctx, vocab_parallel_logits, target): @staticmethod def backward(ctx, grad_output): - # Retrieve tensors from the forward path. softmax, target_mask, masked_target_1d = ctx.saved_tensors @@ -316,7 +306,7 @@ def backward(ctx, grad_output): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(grad_output.unsqueeze(dim=-1)) @@ -326,8 +316,8 @@ def backward(ctx, grad_output): class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the - first and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last)""" + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last)""" @staticmethod def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size): @@ -393,11 +383,11 @@ def __init__( # Word embeddings (parallel). self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx) - self._word_embeddings_key = 'word_embeddings' + self._word_embeddings_key = "word_embeddings" # Position embedding (serial). self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) - self._position_embeddings_key = 'position_embeddings' + self._position_embeddings_key = "position_embeddings" # Initialize the position embeddings. # self.init_method(self.position_embeddings.weight) @@ -405,7 +395,7 @@ def __init__( # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. - self._tokentype_embeddings_key = 'tokentype_embeddings' + self._tokentype_embeddings_key = "tokentype_embeddings" if self.num_tokentypes > 0: self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. @@ -432,9 +422,9 @@ def add_tokentype_embeddings(self, num_tokentypes): This allows us to load the model normally and then add this embedding. """ if self.tokentype_embeddings is not None: - raise Exception('tokentype embeddings is already initialized') + raise Exception("tokentype embeddings is already initialized") if torch.distributed.get_rank() == 0: - print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True) + print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True) self.num_tokentypes = num_tokentypes self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) # Initialize the token-type embeddings. @@ -460,19 +450,16 @@ def forward(self, input_ids, position_ids=None, tokentype_ids=None): embeddings = self.embedding_dropout(embeddings) return embeddings - def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False): """For easy load.""" state_dict_ = {} - state_dict_[self._word_embeddings_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars) + state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars) if self.num_tokentypes > 0: - state_dict_[self._tokentype_embeddings_key] \ - = self.tokentype_embeddings.state_dict( - destination, prefix, keep_vars) + state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict( + destination, prefix, keep_vars + ) return state_dict_ @@ -486,9 +473,8 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'word_embeddings' in key: - state_dict_[key.split('word_embeddings.')[1]] \ - = state_dict[key] + if "word_embeddings" in key: + state_dict_[key.split("word_embeddings.")[1]] = state_dict[key] self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. @@ -498,9 +484,8 @@ def load_state_dict(self, state_dict, strict=True): # for backward compatibility. state_dict_ = {} for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] + if "position_embeddings" in key: + state_dict_[key.split("position_embeddings.")[1]] = state_dict[key] self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. @@ -511,15 +496,14 @@ def load_state_dict(self, state_dict, strict=True): else: # for backward compatibility. for key in state_dict.keys(): - if 'tokentype_embeddings' in key: - state_dict_[key.split('tokentype_embeddings.')[1]] \ - = state_dict[key] + if "tokentype_embeddings" in key: + state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key] if len(state_dict_.keys()) > 0: self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict) else: - print('***WARNING*** expected tokentype embeddings in the ' - 'checkpoint but could not find it', - flush=True) + print( + "***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True + ) class HiddenParallelEmbedding1D(torch.nn.Module): @@ -542,21 +526,21 @@ def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx # Set the details for compatibility. self.padding_idx = padding_idx self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None # Allocate weights and initialize. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + factory_kwargs = {"device": get_current_device(), "dtype": dtype} self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs)) init.uniform_(self.weight, -1, 1) def forward(self, input_): - # Get the embeddings. - output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type, - self.scale_grad_by_freq, self.sparse) + output_parallel = F.embedding( + input_, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse + ) # Reduce across all the model parallel GPUs. output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) @@ -584,11 +568,9 @@ def __init__( # self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx) # (hidden_size/q, vocab_size) self.synced_embed = False - self.head = Linear1D_Row(in_features=embed_dim, - out_features=vocab_size, - bias=False, - dtype=dtype, - parallel_input=False) + self.head = Linear1D_Row( + in_features=embed_dim, out_features=vocab_size, bias=False, dtype=dtype, parallel_input=False + ) def forward(self, x: Tensor) -> Tensor: if self.synced_embed: diff --git a/examples/language/gpt/titans/model/gpt1d.py b/examples/language/gpt/titans/model/gpt1d.py index 746acbf7dccd..f8e2f42e11cb 100644 --- a/examples/language/gpt/titans/model/gpt1d.py +++ b/examples/language/gpt/titans/model/gpt1d.py @@ -18,18 +18,21 @@ from colossalai.utils import checkpoint __all__ = [ - 'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D' + "GPTMLP1D", + "GPTSelfAttention1D", + "GPTTransformerLayer1D", + "FusedGPTSelfAttention1D", + "FusedGPTTransformerLayer1D", ] class GPTMLP1D(ParallelLayer): - def __init__( self, in_features: int, mlp_ratio: int, - act_func: str = 'gelu', - dropout_prob: float = 0., + act_func: str = "gelu", + dropout_prob: float = 0.0, dtype=None, checkpoint: bool = False, skip_bias_add: bool = False, @@ -82,7 +85,6 @@ def forward(self, hidden_states: Tensor) -> Tensor: class GenericGPTSelfAttention1D(ParallelLayer): - def __init__( self, hidden_size: int, @@ -118,8 +120,10 @@ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_lay def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor: query_key_value = self.query_key_value(hidden_states) - new_qkv_shape = query_key_value.shape[:-1] + \ - (self.num_attention_heads_per_partition, 3 * self.attention_head_size) + new_qkv_shape = query_key_value.shape[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.attention_head_size, + ) query_key_value = query_key_value.view(new_qkv_shape) query_key_value = query_key_value.permute((0, 2, 1, 3)) query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1) @@ -152,28 +156,32 @@ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor: class GPTSelfAttention1D(GenericGPTSelfAttention1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - attention_dropout_prob: float, - hidden_dropout_prob: float, - dtype=None, - checkpoint: bool = False, - max_position_embeddings=1024): - super().__init__(hidden_size, - num_attention_heads, - attention_dropout_prob, - hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings) + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False, + max_position_embeddings=1024, + ): + super().__init__( + hidden_size, + num_attention_heads, + attention_dropout_prob, + hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + ) self.softmax = nn.Softmax(dim=-1) max_positions = max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), - dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), ) self.register_buffer("masked_bias", torch.tensor(-1e4)) @@ -181,7 +189,7 @@ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_lay attention_scores = attention_scores / math.sqrt(self.attention_head_size) # causal mask query_length, key_length = query_layer.size(-2), key_layer.size(-2) - causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool() + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores)) if attention_mask is not None: # Apply the attention mask @@ -191,50 +199,56 @@ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_lay class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - attention_dropout_prob: float, - hidden_dropout_prob: float, - dtype=None, - checkpoint: bool = False, - max_position_embeddings=1024): - super().__init__(hidden_size, - num_attention_heads, - attention_dropout_prob, - hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings) - self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True, - input_in_bf16=False, - attn_mask_type=AttnMaskType.causal, - scaled_masked_softmax_fusion=True, - mask_func=None, - softmax_in_fp32=True, - scale=math.sqrt(self.attention_head_size)) + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + attention_dropout_prob: float, + hidden_dropout_prob: float, + dtype=None, + checkpoint: bool = False, + max_position_embeddings=1024, + ): + super().__init__( + hidden_size, + num_attention_heads, + attention_dropout_prob, + hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + ) + self.softmax = kernel.FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + mask_func=None, + softmax_in_fp32=True, + scale=math.sqrt(self.attention_head_size), + ) def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer): return self.softmax(attention_scores, attention_mask) class GenericGPTTransformerLayer1D(ParallelLayer): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - act_func: str = 'gelu', - mlp_ratio: float = 4.0, - attention_dropout_prob: float = 0., - hidden_dropout_prob: float = 0., - dtype=None, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - attention=None, - layer_norm=None): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + act_func: str = "gelu", + mlp_ratio: float = 4.0, + attention_dropout_prob: float = 0.0, + hidden_dropout_prob: float = 0.0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + attention=None, + layer_norm=None, + ): super().__init__() self.checkpoint = checkpoint self.dtype = dtype @@ -288,62 +302,68 @@ def forward(self, hidden_states, attention_mask): class GPTTransformerLayer1D(GenericGPTTransformerLayer1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - act_func: str = 'gelu', - mlp_ratio: float = 4, - attention_dropout_prob: float = 0, - hidden_dropout_prob: float = 0, - dtype=None, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 0.00001, - apply_post_layer_norm: bool = False): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + act_func: str = "gelu", + mlp_ratio: float = 4, + attention_dropout_prob: float = 0, + hidden_dropout_prob: float = 0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 0.00001, + apply_post_layer_norm: bool = False, + ): attention = GPTSelfAttention1D layer_norm = nn.LayerNorm - super().__init__(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attention_dropout_prob, - hidden_dropout_prob=hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm, - attention=attention, - layer_norm=layer_norm) + super().__init__( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attention_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + attention=attention, + layer_norm=layer_norm, + ) class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D): - - def __init__(self, - hidden_size: int, - num_attention_heads: int, - act_func: str = 'gelu', - mlp_ratio: float = 4, - attention_dropout_prob: float = 0, - hidden_dropout_prob: float = 0, - dtype=None, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 0.00001, - apply_post_layer_norm: bool = False): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + act_func: str = "gelu", + mlp_ratio: float = 4, + attention_dropout_prob: float = 0, + hidden_dropout_prob: float = 0, + dtype=None, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 0.00001, + apply_post_layer_norm: bool = False, + ): attention = FusedGPTSelfAttention1D layer_norm = kernel.LayerNorm - super().__init__(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attention_dropout_prob, - hidden_dropout_prob=hidden_dropout_prob, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm, - attention=attention, - layer_norm=layer_norm) + super().__init__( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attention_dropout_prob, + hidden_dropout_prob=hidden_dropout_prob, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + attention=attention, + layer_norm=layer_norm, + ) diff --git a/examples/language/gpt/titans/model/pipeline_gpt1d.py b/examples/language/gpt/titans/model/pipeline_gpt1d.py index a9da246faf82..83158cb44e0c 100644 --- a/examples/language/gpt/titans/model/pipeline_gpt1d.py +++ b/examples/language/gpt/titans/model/pipeline_gpt1d.py @@ -17,17 +17,16 @@ from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D __all__ = [ - 'GPT2_small_pipeline_1D', - 'GPT2_exlarge_pipeline_1D', - 'GPT3_pipeline_1D', - 'GPT2_exlarge_pipeline_hybrid', - 'GPT2_small_pipeline_hybrid', - 'GPT3_pipeline_hybrid', + "GPT2_small_pipeline_1D", + "GPT2_exlarge_pipeline_1D", + "GPT3_pipeline_1D", + "GPT2_exlarge_pipeline_hybrid", + "GPT2_small_pipeline_hybrid", + "GPT3_pipeline_hybrid", ] class GenericPipelineGPT(nn.Module): - def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None: super().__init__() self.embedding = embedding @@ -44,7 +43,7 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None): batch_size = hidden_states.shape[0] attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 for block in self.blocks: hidden_states, attention_mask = block(hidden_states, attention_mask) @@ -54,25 +53,26 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None): class PipelineGPT1D(GenericPipelineGPT): - - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - embed_drop_rate: float = 0., - act_func: str = 'gelu', - mlp_ratio: int = 4.0, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - first: bool = False, - last: bool = False, - embed_split_hidden=False): + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0.0, + act_func: str = "gelu", + mlp_ratio: int = 4.0, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False, + ): embedding = None norm = None head = None @@ -83,19 +83,24 @@ def __init__(self, head_cls = HiddenParallelGPTLMHead1D if first: embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype) - blocks = nn.ModuleList([ - GPTTransformerLayer1D(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attn_drop_rate, - hidden_dropout_prob=drop_rate, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers) - ]) + blocks = nn.ModuleList( + [ + GPTTransformerLayer1D( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attn_drop_rate, + hidden_dropout_prob=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + ) + for _ in range(num_layers) + ] + ) if last: norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype) @@ -103,25 +108,26 @@ def __init__(self, class FusedPipelineGPT1D(GenericPipelineGPT): - - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - embed_drop_rate: float = 0., - act_func: str = 'gelu', - mlp_ratio: int = 4.0, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - first: bool = False, - last: bool = False, - embed_split_hidden=False): + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0.0, + act_func: str = "gelu", + mlp_ratio: int = 4.0, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False, + ): embedding = None norm = None head = None @@ -132,19 +138,24 @@ def __init__(self, head_cls = HiddenParallelGPTLMHead1D if first: embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype) - blocks = nn.ModuleList([ - FusedGPTTransformerLayer1D(hidden_size, - num_attention_heads, - act_func=act_func, - mlp_ratio=mlp_ratio, - attention_dropout_prob=attn_drop_rate, - hidden_dropout_prob=drop_rate, - dtype=dtype, - checkpoint=checkpoint, - max_position_embeddings=max_position_embeddings, - layer_norm_epsilon=layer_norm_epsilon, - apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers) - ]) + blocks = nn.ModuleList( + [ + FusedGPTTransformerLayer1D( + hidden_size, + num_attention_heads, + act_func=act_func, + mlp_ratio=mlp_ratio, + attention_dropout_prob=attn_drop_rate, + hidden_dropout_prob=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + max_position_embeddings=max_position_embeddings, + layer_norm_epsilon=layer_norm_epsilon, + apply_post_layer_norm=apply_post_layer_norm, + ) + for _ in range(num_layers) + ] + ) if last: norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon) head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype) @@ -153,7 +164,7 @@ def __init__(self, def forward(self, hidden_states=None, input_ids=None, attention_mask=None): if self.embedding is not None: hidden_states = self.embedding(input_ids=input_ids) - attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility for block in self.blocks: hidden_states, attention_mask = block(hidden_states, attention_mask) if self.norm is not None: @@ -162,44 +173,48 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None): class PipelineGPTHybrid(GenericPipelineGPT): - - def __init__(self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - vocab_size: int = 50304, - embed_drop_rate: float = 0., - act_func: str = 'gelu', - mlp_ratio: int = 4, - attn_drop_rate: float = 0., - drop_rate: float = 0., - dtype: torch.dtype = torch.float, - checkpoint: bool = False, - max_position_embeddings: int = 1024, - layer_norm_epsilon: float = 1e-5, - apply_post_layer_norm: bool = False, - first: bool = False, - last: bool = False, - embed_split_hidden=False): + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + vocab_size: int = 50304, + embed_drop_rate: float = 0.0, + act_func: str = "gelu", + mlp_ratio: int = 4, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: bool = False, + max_position_embeddings: int = 1024, + layer_norm_epsilon: float = 1e-5, + apply_post_layer_norm: bool = False, + first: bool = False, + last: bool = False, + embed_split_hidden=False, + ): embedding = None norm = None head = None if first: - embedding = col_gpt.GPTEmbedding(hidden_size, - vocab_size, - max_position_embeddings, - dropout=embed_drop_rate, - dtype=dtype) - blocks = nn.ModuleList([ - col_gpt.GPTBlock(hidden_size, - num_attention_heads, - mlp_ratio=mlp_ratio, - attention_dropout=attn_drop_rate, - dropout=drop_rate, - dtype=dtype, - checkpoint=checkpoint, - activation=nn.functional.gelu) for _ in range(num_layers) - ]) + embedding = col_gpt.GPTEmbedding( + hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype + ) + blocks = nn.ModuleList( + [ + col_gpt.GPTBlock( + hidden_size, + num_attention_heads, + mlp_ratio=mlp_ratio, + attention_dropout=attn_drop_rate, + dropout=drop_rate, + dtype=dtype, + checkpoint=checkpoint, + activation=nn.functional.gelu, + ) + for _ in range(num_layers) + ] + ) if last: norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) # head = col_gpt.GPTLMHead(vocab_size=vocab_size, @@ -215,7 +230,7 @@ def _filter_kwargs(func, kwargs): return {k: v for k, v in kwargs.items() if k in sig.parameters} -def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs): +def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device("cuda"), **kwargs): logger = get_dist_logger() if gpc.is_initialized(ParallelMode.PIPELINE): @@ -233,10 +248,10 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=to parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] models = [] for start, end in parts: - kwargs['num_layers'] = end - start - kwargs['first'] = start == 0 - kwargs['last'] = end == num_layers - logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + kwargs["num_layers"] = end - start + kwargs["first"] = start == 0 + kwargs["last"] = end == num_layers + logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers") chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device) if wrapper is not None: @@ -253,70 +268,82 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=to numel = 0 for _, param in model.named_parameters(recurse=True): numel += param.numel() - logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB') + logger.info(f"Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB") return model -def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs): +def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device("cuda"), fused=False, **kwargs): model = FusedPipelineGPT1D if fused else PipelineGPT1D return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs) -def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): +def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs) def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): - cfg = dict(hidden_size=768, - num_attention_heads=12, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=768, + num_attention_heads=12, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg) def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): - cfg = dict(hidden_size=1600, - num_attention_heads=32, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=1600, + num_attention_heads=32, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg) def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False): - cfg = dict(hidden_size=12288, - num_attention_heads=96, - checkpoint=checkpoint, - max_position_embeddings=2048, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=12288, + num_attention_heads=96, + checkpoint=checkpoint, + max_position_embeddings=2048, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg) def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): - cfg = dict(hidden_size=1600, - num_attention_heads=32, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=1600, + num_attention_heads=32, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg) def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): - cfg = dict(hidden_size=768, - num_attention_heads=12, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=768, + num_attention_heads=12, + checkpoint=checkpoint, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg) def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False): - cfg = dict(hidden_size=12288, - num_attention_heads=96, - checkpoint=checkpoint, - max_position_embeddings=2048, - dtype=dtype, - embed_split_hidden=embed_split_hidden) + cfg = dict( + hidden_size=12288, + num_attention_heads=96, + checkpoint=checkpoint, + max_position_embeddings=2048, + dtype=dtype, + embed_split_hidden=embed_split_hidden, + ) return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg) diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py index 3ed18b21fff5..b9d802f01cc9 100644 --- a/examples/language/gpt/titans/train_gpt.py +++ b/examples/language/gpt/titans/train_gpt.py @@ -14,7 +14,7 @@ from colossalai.legacy.zero.init_ctx import ZeroInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn import LinearWarmupLR -from colossalai.utils import colo_set_process_memory_fraction, is_using_pp +from colossalai.utils import is_using_pp from colossalai.utils.timer import MultiTimer @@ -30,8 +30,8 @@ def calc_local_model_size(model: torch.nn.Module): def main(): parser = colossalai.get_default_parser() - parser.add_argument('--from_torch', default=False, action='store_true') - parser.add_argument('--use_dummy_dataset', default=False, action='store_true') + parser.add_argument("--from_torch", default=False, action="store_true") + parser.add_argument("--use_dummy_dataset", default=False, action="store_true") args = parser.parse_args() disable_existing_loggers() if args.from_torch: @@ -40,28 +40,27 @@ def main(): colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42) logger = get_dist_logger() - data_path = None if args.use_dummy_dataset else os.environ['DATA'] - logger.info(f'Build data loader from path {data_path}', ranks=[0]) + data_path = None if args.use_dummy_dataset else os.environ["DATA"] + logger.info(f"Build data loader from path {data_path}", ranks=[0]) train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN) - train_dataloader = utils.get_dataloader(train_ds, - seed=42, - batch_size=gpc.config.BATCH_SIZE, - pin_memory=True, - shuffle=True, - drop_last=True) - - logger.info('Build model', ranks=[0]) + train_dataloader = utils.get_dataloader( + train_ds, seed=42, batch_size=gpc.config.BATCH_SIZE, pin_memory=True, shuffle=True, drop_last=True + ) + + logger.info("Build model", ranks=[0]) use_pipeline = is_using_pp() - use_interleaved = hasattr(gpc.config.model, 'num_chunks') - use_zero3 = hasattr(gpc.config, 'zero') + use_interleaved = hasattr(gpc.config.model, "num_chunks") + use_zero3 = hasattr(gpc.config, "zero") ctx = contextlib.nullcontext() if use_zero3: - ctx = ZeroInitContext(target_device=torch.cuda.current_device(), - shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True) + ctx = ZeroInitContext( + target_device=torch.cuda.current_device(), + shard_strategy=gpc.config.zero.model_config.shard_strategy, + shard_param=True, + ) with ctx: - model = gpc.config.model.pop('type')(**gpc.config.model) + model = gpc.config.model.pop("type")(**gpc.config.model) if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList): model = nn.ModuleList([model]) @@ -70,25 +69,31 @@ def main(): else: numel = calc_local_model_size(model) - tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \ - * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4) - - criterion = getattr(gpc.config, 'loss_fn', None) + tflop = ( + numel + * gpc.config.BATCH_SIZE + * gpc.config.SEQ_LEN + * gpc.get_world_size(ParallelMode.MODEL) + * gpc.get_world_size(ParallelMode.DATA) + * 8 + / (1024**4) + ) + + criterion = getattr(gpc.config, "loss_fn", None) if criterion is not None: criterion = criterion.type() else: criterion = GPTLMLoss() - logger.info('Build optimizer', ranks=[0]) - optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer) + logger.info("Build optimizer", ranks=[0]) + optimizer = gpc.config.optimizer.pop("type")(model.parameters(), **gpc.config.optimizer) lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5) - engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model, - optimizer, - criterion, - train_dataloader=train_dataloader, - lr_scheduler=lr_scheduler) - global_batch_size = gpc.config.BATCH_SIZE * \ - gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) - logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0]) + engine, train_dataloader, _, lr_scheduler = colossalai.initialize( + model, optimizer, criterion, train_dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) + global_batch_size = ( + gpc.config.BATCH_SIZE * gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1) + ) + logger.info(f"Init done, global batch size = {global_batch_size}", ranks=[0]) timier = MultiTimer() trainer = Trainer(engine=engine, logger=logger, timer=timier) hook_list = [ @@ -98,16 +103,18 @@ def main(): hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop), hooks.LogMetricByStepHook(), hooks.LogMemoryByEpochHook(logger), - # hooks.LogMemoryByEpochHook(logger), - # hooks.LogTimingByEpochHook(timer, logger), + # hooks.LogMemoryByEpochHook(logger), + # hooks.LogTimingByEpochHook(timer, logger), ] - trainer.fit(train_dataloader=train_dataloader, - epochs=gpc.config.NUM_EPOCHS, - test_interval=1, - hooks=hook_list, - display_progress=True, - return_output_label=False) + trainer.fit( + train_dataloader=train_dataloader, + epochs=gpc.config.NUM_EPOCHS, + test_interval=1, + hooks=hook_list, + display_progress=True, + return_output_label=False, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/llama2/attn.py b/examples/language/llama2/attn.py index 15f76647c87b..2b2356b18b70 100644 --- a/examples/language/llama2/attn.py +++ b/examples/language/llama2/attn.py @@ -9,12 +9,14 @@ SUPPORT_FLASH2 = False try: import xformers.ops as xops + SUPPORT_XFORMERS = True except ImportError: pass try: from flash_attn import flash_attn_func + SUPPORT_FLASH2 = True except ImportError: pass @@ -62,10 +64,9 @@ def llama_flash_attention( if SUPPORT_FLASH2: attn_output = flash_attn_func(query_states, key_states, value_states, causal=True) else: - attn_output = xops.memory_efficient_attention(query_states, - key_states, - value_states, - attn_bias=xops.LowerTriangularMask()) + attn_output = xops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask() + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index 1b947cef9080..ce13ebbf617d 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -25,21 +25,22 @@ # ============================== MODEL_CONFIGS = { - '7b': - LlamaConfig(max_position_embeddings=4096), - '13b': - LlamaConfig(hidden_size=5120, - intermediate_size=13824, - num_hidden_layers=40, - num_attention_heads=40, - max_position_embeddings=4096), - '70b': - LlamaConfig(hidden_size=8192, - intermediate_size=28672, - num_hidden_layers=80, - num_attention_heads=64, - max_position_embeddings=4096, - num_key_value_heads=8), + "7b": LlamaConfig(max_position_embeddings=4096), + "13b": LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=40, + num_attention_heads=40, + max_position_embeddings=4096, + ), + "70b": LlamaConfig( + hidden_size=8192, + intermediate_size=28672, + num_hidden_layers=80, + num_attention_heads=64, + max_position_embeddings=4096, + num_key_value_heads=8, + ), } @@ -48,31 +49,31 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration') - parser.add_argument('-p', - '--plugin', - choices=['gemini', 'gemini_auto', 'fsdp', 'fsdp_cpu', '3d', '3d_cpu'], - default='gemini', - help='Choose which plugin to use') - parser.add_argument('-b', '--batch_size', type=int, default=2, help='Batch size') - parser.add_argument('-s', '--num_steps', type=int, default=5, help='Number of steps to run') - parser.add_argument('-i', '--ignore_steps', type=int, default=2, help='Number of steps to ignore') - parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing') - parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length') - parser.add_argument('-w', - '--warmup_ratio', - type=float, - default=0.8, - help='warm up ratio of non-model data. Only for gemini-auto') - parser.add_argument('-m', '--memory_limit', type=int, help='Gemini memory limit in mb') - parser.add_argument('-x', '--xformers', action='store_true', help='Use xformers') - parser.add_argument('--shard_param_frac', type=float, default=1.0, help='Shard param fraction. Only for gemini') - parser.add_argument('--offload_optim_frac', type=float, default=0.0, help='Offload optim fraction. Only for gemini') - parser.add_argument('--offload_param_frac', type=float, default=0.0, help='Offload param fraction. Only for gemini') - parser.add_argument('--tp', type=int, default=1, help='Tensor parallel size') - parser.add_argument('--pp', type=int, default=1, help='Pipeline parallel size') - parser.add_argument('--mbs', type=int, default=1) - parser.add_argument('--zero', type=int, default=0) + parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"], + default="gemini", + help="Choose which plugin to use", + ) + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Batch size") + parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run") + parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument( + "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto" + ) + parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb") + parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers") + parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini") + parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") + parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") + parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") + parser.add_argument("--mbs", type=int, default=1) + parser.add_argument("--zero", type=int, default=0) args = parser.parse_args() colossalai.launch_from_torch({}) @@ -85,56 +86,67 @@ def empty_init(): # Initialize Booster # ============================== use_empty_init = True - if args.plugin == 'gemini': - plugin = GeminiPlugin(precision='bf16', - shard_param_frac=args.shard_param_frac, - offload_optim_frac=args.offload_optim_frac, - offload_param_frac=args.offload_param_frac) - elif args.plugin == 'gemini_auto': - plugin = GeminiPlugin(placement_policy='auto', precision='bf16', warmup_non_model_data_ratio=args.warmup_ratio) - elif args.plugin == 'fsdp': + if args.plugin == "gemini": + plugin = GeminiPlugin( + precision="bf16", + shard_param_frac=args.shard_param_frac, + offload_optim_frac=args.offload_optim_frac, + offload_param_frac=args.offload_param_frac, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio) + elif args.plugin == "fsdp": if use_empty_init: plugin = TorchFSDPPlugin( - mixed_precision=MixedPrecision(param_dtype=torch.float16, - reduce_dtype=torch.float16, - buffer_dtype=torch.float16), + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), param_init_fn=empty_init(), ) else: - plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision( - param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16)) - elif args.plugin == 'fsdp_cpu': + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ) + ) + elif args.plugin == "fsdp_cpu": if use_empty_init: plugin = TorchFSDPPlugin( - mixed_precision=MixedPrecision(param_dtype=torch.float16, - reduce_dtype=torch.float16, - buffer_dtype=torch.float16), + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), cpu_offload=CPUOffload(offload_params=True), param_init_fn=empty_init(), ) else: - plugin = TorchFSDPPlugin(mixed_precision=MixedPrecision(param_dtype=torch.float16, - reduce_dtype=torch.float16, - buffer_dtype=torch.float16), - cpu_offload=CPUOffload(offload_params=True)) - elif args.plugin == '3d': - plugin = HybridParallelPlugin(tp_size=args.tp, - pp_size=args.pp, - zero_stage=args.zero, - enable_fused_normalization=True, - num_microbatches=args.mbs, - precision='bf16') - elif args.plugin == '3d_cpu': - plugin = HybridParallelPlugin(tp_size=args.tp, - pp_size=args.pp, - zero_stage=args.zero, - cpu_offload=True, - enable_fused_normalization=True, - num_microbatches=args.mbs, - initial_scale=2**8, - precision='bf16') + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), + cpu_offload=CPUOffload(offload_params=True), + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + zero_stage=args.zero, + enable_fused_normalization=True, + num_microbatches=args.mbs, + precision="bf16", + ) + elif args.plugin == "3d_cpu": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + zero_stage=args.zero, + cpu_offload=True, + enable_fused_normalization=True, + num_microbatches=args.mbs, + initial_scale=2**8, + precision="bf16", + ) else: - raise ValueError(f'Unknown plugin {args.plugin}') + raise ValueError(f"Unknown plugin {args.plugin}") booster = Booster(plugin=plugin) @@ -144,17 +156,19 @@ def empty_init(): dp_size = plugin.dp_size if isinstance(plugin, HybridParallelPlugin) else coordinator.world_size config = MODEL_CONFIGS[args.config] - dataset = RandomDataset(num_samples=args.batch_size * args.num_steps * dp_size, - max_length=args.max_length, - vocab_size=config.vocab_size) + dataset = RandomDataset( + num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size + ) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # ============================== # Initialize Model and Optimizer # ============================== - init_ctx = LazyInitContext( - default_device=get_current_device()) if isinstance(plugin, - (GeminiPlugin, HybridParallelPlugin)) else nullcontext() + init_ctx = ( + LazyInitContext(default_device=get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) with init_ctx: model = LlamaForCausalLM(config) @@ -163,38 +177,36 @@ def empty_init(): model.gradient_checkpointing_enable() if args.xformers: - assert SUPPORT_FLASH, 'Use flash attention while xfomers is not installed' + assert SUPPORT_FLASH, "Use flash attention while xfomers is not installed" replace_xformers(model) model_numel = get_model_numel(model) - coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}') - performance_evaluator = PerformanceEvaluator(model_numel, - args.grad_checkpoint, - args.ignore_steps, - dp_world_size=dp_size) + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + performance_evaluator = PerformanceEvaluator( + model_numel, args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size + ) optimizer = HybridAdam(model.parameters()) torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) torch.set_default_dtype(torch.float) - coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master( - f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB') + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: data_iter = iter(dataloader) - for step in tqdm(range(len(dataloader)), desc='Step', disable=not coordinator.is_master()): + for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()): performance_evaluator.on_step_start(step) - booster.execute_pipeline(data_iter, - model, - criterion=lambda outputs, inputs: outputs[0], - optimizer=optimizer, - return_loss=False) + booster.execute_pipeline( + data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False + ) optimizer.step() optimizer.zero_grad() performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length)) else: - for step, batch in enumerate(tqdm(dataloader, desc='Step', disable=not coordinator.is_master())): + for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] @@ -204,8 +216,8 @@ def empty_init(): performance_evaluator.on_step_end(**batch) performance_evaluator.on_fit_end() - coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/llama2/data_utils.py b/examples/language/llama2/data_utils.py index 25d0e1bd9f46..a438833e1680 100644 --- a/examples/language/llama2/data_utils.py +++ b/examples/language/llama2/data_utils.py @@ -12,21 +12,22 @@ class StatefulDistributedSampler(DistributedSampler): - - def __init__(self, - dataset: Dataset, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = True, - seed: int = 0, - drop_last: bool = False) -> None: + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) self.start_index: int = 0 def __iter__(self) -> Iterator: iterator = super().__iter__() indices = list(iterator) - indices = indices[self.start_index:] + indices = indices[self.start_index :] return iter(indices) def __len__(self) -> int: @@ -36,15 +37,17 @@ def set_start_index(self, start_index: int) -> None: self.start_index = start_index -def prepare_dataloader(dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - process_group: Optional[ProcessGroup] = None, - **kwargs): +def prepare_dataloader( + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + process_group: Optional[ProcessGroup] = None, + **kwargs, +): r""" Prepare a dataloader for distributed training. The dataloader will be wrapped by `torch.utils.data.DataLoader` and `StatefulDistributedSampler`. @@ -68,10 +71,9 @@ def prepare_dataloader(dataset, """ _kwargs = kwargs.copy() process_group = process_group or _get_default_group() - sampler = StatefulDistributedSampler(dataset, - num_replicas=process_group.size(), - rank=process_group.rank(), - shuffle=shuffle) + sampler = StatefulDistributedSampler( + dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle + ) # Deterministic dataloader def seed_worker(worker_id): @@ -80,28 +82,29 @@ def seed_worker(worker_id): torch.manual_seed(worker_seed) random.seed(worker_seed) - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) def load_json(file_path: str): - with open(file_path, 'r') as f: + with open(file_path, "r") as f: return json.load(f) def save_json(data, file_path: str): - with open(file_path, 'w') as f: + with open(file_path, "w") as f: json.dump(data, f, indent=4) class RandomDataset(Dataset): - def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): self.num_samples = num_samples self.max_length = max_length @@ -113,7 +116,7 @@ def __len__(self): def __getitem__(self, idx): return { - 'input_ids': self.input_ids[idx], - 'attention_mask': self.attention_mask[idx], - 'labels': self.input_ids[idx] + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], } diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py index 0efbf193c9a9..33aa1d33e6ba 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -39,20 +39,20 @@ def format_numel_str(numel: int) -> str: M = 1024**2 K = 1024 if numel >= B: - return f'{numel / B:.2f} B' + return f"{numel / B:.2f} B" elif numel >= M: - return f'{numel / M:.2f} M' + return f"{numel / M:.2f} M" elif numel >= K: - return f'{numel / K:.2f} K' + return f"{numel / K:.2f} K" else: - return f'{numel}' + return f"{numel}" def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): - texts = [sample['prompt'] + sample['completion'] for sample in batch] - data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length) + texts = [sample["prompt"] + sample["completion"] for sample in batch] + data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) data = {k: v.cuda() for k, v in data.items()} - data['labels'] = data['input_ids'].clone() + data["labels"] = data["input_ids"].clone() return data @@ -62,30 +62,40 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: return tensor -def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int, - batch_size: int, coordinator: DistCoordinator, save_dir: str): - save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}') - os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True) - - booster.save_model(model, os.path.join(save_dir, 'model'), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True) - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler')) +def save( + booster: Booster, + model: nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, + save_dir: str, +): + save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "model"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) running_states = { - 'epoch': epoch, - 'step': step, - 'sample_start_index': step * batch_size, + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, } if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, 'running_states.json')) + save_json(running_states, os.path.join(save_dir, "running_states.json")) -def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, - load_dir: str) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, 'model')) - booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer')) - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler')) - running_states = load_json(os.path.join(load_dir, 'running_states.json')) - return running_states['epoch'], running_states['step'], running_states['sample_start_index'] +def load( + booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str +) -> Tuple[int, int, int]: + booster.load_model(model, os.path.join(load_dir, "model")) + booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) + running_states = load_json(os.path.join(load_dir, "running_states.json")) + return running_states["epoch"], running_states["step"], running_states["sample_start_index"] def _criterion(outputs, inputs): @@ -97,27 +107,29 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('--model_path', type=str, help="pretrained checkpoint path, used with mode==finetune") - parser.add_argument('-p', - '--plugin', - choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'], - default='gemini', - help='Choose which plugin to use') - parser.add_argument('-d', '--dataset', type=str, default='yizhongw/self_instruct', help='Data set path') - parser.add_argument('--task_name', type=str, default="super_natural_instructions", help='task to run') - parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs') - parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size') - parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate') - parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay') - parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing') - parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length') - parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision') - parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval') - parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory') - parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint') - parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping') - parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory') - parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention') + parser.add_argument("--model_path", type=str, help="pretrained checkpoint path, used with mode==finetune") + parser.add_argument( + "-p", + "--plugin", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], + default="gemini", + help="Choose which plugin to use", + ) + parser.add_argument("-d", "--dataset", type=str, default="yizhongw/self_instruct", help="Data set path") + parser.add_argument("--task_name", type=str, default="super_natural_instructions", help="task to run") + parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") + parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") + parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") + parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") args = parser.parse_args() # ============================== @@ -129,36 +141,34 @@ def main(): # ============================== # Initialize Booster # ============================== - if args.plugin == 'gemini': + if args.plugin == "gemini": plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) - elif args.plugin == 'gemini_auto': - plugin = GeminiPlugin(precision=args.mixed_precision, - placement_policy='auto', - initial_scale=2**16, - max_norm=args.grad_clip) - elif args.plugin == 'zero2': - plugin = LowLevelZeroPlugin(stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - max_norm=args.grad_clip) - elif args.plugin == 'zero2_cpu': - plugin = LowLevelZeroPlugin(stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - cpu_offload=True, - max_norm=args.grad_clip) - elif args.plugin == 'hybrid_parallel': + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip + ) + elif args.plugin == "hybrid_parallel": # modify the param accordingly, default configuration is for llama2-7b - plugin = HybridParallelPlugin(tp_size=4, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_jit_fused=False, - zero_stage=0, - precision='fp32', - initial_scale=1) + plugin = HybridParallelPlugin( + tp_size=4, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_jit_fused=False, + zero_stage=0, + precision="fp32", + initial_scale=1, + ) else: - raise ValueError(f'Unknown plugin {args.plugin}') + raise ValueError(f"Unknown plugin {args.plugin}") booster = Booster(plugin=plugin) @@ -179,8 +189,9 @@ def main(): config = LlamaConfig.from_pretrained(args.model_path) # use lazy init when using GeminiPlugin - init_ctx = LazyInitContext( - default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + init_ctx = ( + LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + ) with init_ctx: model = LlamaForCausalLM(config) @@ -188,57 +199,56 @@ def main(): # ============================== # Initialize Tokenizer, Dataset and Dataloader # ============================== - tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer') + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 tokenizer.pad_token = tokenizer.unk_token dataset = load_dataset(args.dataset, args.task_name) - train_ds = dataset['train'] - dataloader = prepare_dataloader(train_ds, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=partial(tokenize_batch_for_finetune, - tokenizer=tokenizer, - max_length=args.max_length)) + train_ds = dataset["train"] + dataloader = prepare_dataloader( + train_ds, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=partial(tokenize_batch_for_finetune, tokenizer=tokenizer, max_length=args.max_length), + ) if args.grad_checkpoint: model.gradient_checkpointing_enable() if args.flash_attention: - assert SUPPORT_XFORMERS, 'Use flash attention while xfomers is not installed' + assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed" replace_xformers(model) model_numel = get_model_numel(model) - coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}') + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) total_step = args.num_epochs * len(dataloader) - lr_scheduler = CosineAnnealingWarmupLR(optimizer, - total_steps=total_step, - warmup_steps=math.ceil(total_step * 0.03), - eta_min=0.1 * args.lr) - default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16 + lr_scheduler = CosineAnnealingWarmupLR( + optimizer, total_steps=total_step, warmup_steps=math.ceil(total_step * 0.03), eta_min=0.1 * args.lr + ) + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost(model, - optimizer, - dataloader=dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler + ) torch.set_default_dtype(torch.float) booster.load_model(model, args.model_path) - coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master( - f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB') + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) # load checkpoint if specified start_epoch = 0 start_step = 0 sampler_start_idx = 0 if args.load is not None: - coordinator.print_on_master('Loading checkpoint') + coordinator.print_on_master("Loading checkpoint") start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) - coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}') + coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") num_steps_per_epoch = len(dataloader) @@ -249,19 +259,18 @@ def main(): step_nums = num_steps_per_epoch - start_step dataloader_iter = iter(dataloader) - with tqdm(range(step_nums), - desc=f'Epoch {epoch}', - disable=not print_flag, - total=num_steps_per_epoch, - initial=start_step) as pbar: + with tqdm( + range(step_nums), + desc=f"Epoch {epoch}", + disable=not print_flag, + total=num_steps_per_epoch, + initial=start_step, + ) as pbar: for step in pbar: if use_pipeline: - outputs = booster.execute_pipeline(dataloader_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) loss = outputs["loss"] else: batch = next(dataloader_iter) @@ -276,20 +285,29 @@ def main(): if not use_pipeline: all_reduce_mean(loss) if print_flag: - pbar.set_postfix({'loss': loss.item()}) - writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step) + pbar.set_postfix({"loss": loss.item()}) + writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) if args.save_interval > 0 and (step + 1) % args.save_interval == 0: - coordinator.print_on_master(f'Saving checkpoint') - save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator, - args.save_dir) - coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}') + coordinator.print_on_master(f"Saving checkpoint") + save( + booster, + model, + optimizer, + lr_scheduler, + epoch, + step + 1, + args.batch_size, + coordinator, + args.save_dir, + ) + coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(0) start_step = 0 - coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/llama2/model_utils.py b/examples/language/llama2/model_utils.py index 431ff5cfb446..63569bc61143 100644 --- a/examples/language/llama2/model_utils.py +++ b/examples/language/llama2/model_utils.py @@ -23,10 +23,10 @@ def format_numel_str(numel: int) -> str: M = 1024**2 K = 1024 if numel >= B: - return f'{numel / B:.2f} B' + return f"{numel / B:.2f} B" elif numel >= M: - return f'{numel / M:.2f} M' + return f"{numel / M:.2f} M" elif numel >= K: - return f'{numel / K:.2f} K' + return f"{numel / K:.2f} K" else: - return f'{numel}' + return f"{numel}" diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py index 711b99c54360..a57c1e0e9ae3 100644 --- a/examples/language/llama2/performance_evaluator.py +++ b/examples/language/llama2/performance_evaluator.py @@ -10,9 +10,9 @@ def divide(x: float, y: float) -> float: if y == 0: - return float('inf') - elif y == float('inf'): - return float('nan') + return float("inf") + elif y == float("inf"): + return float("nan") return x / y @@ -27,10 +27,9 @@ def all_reduce_mean(x: float, world_size: int) -> float: class Timer: - def __init__(self) -> None: self.start_time: Optional[float] = None - self.duration: float = 0. + self.duration: float = 0.0 def start(self) -> None: self.start_time = time() @@ -41,7 +40,7 @@ def end(self) -> None: self.start_time = None def reset(self) -> None: - self.duration = 0. + self.duration = 0.0 class PerformanceEvaluator: @@ -56,11 +55,13 @@ class PerformanceEvaluator: ignore_episodes: The number of episodes to ignore when calculating the performance. """ - def __init__(self, - model_numel: int, - enable_grad_checkpoint: bool = False, - ignore_steps: int = 0, - dp_world_size: Optional[int] = None) -> None: + def __init__( + self, + model_numel: int, + enable_grad_checkpoint: bool = False, + ignore_steps: int = 0, + dp_world_size: Optional[int] = None, + ) -> None: self.model_numel = model_numel self.enable_grad_checkpoint = enable_grad_checkpoint self.ignore_steps = ignore_steps @@ -96,7 +97,9 @@ def on_fit_end(self) -> None: mp_world_size = self.coordinator.world_size // self.dp_world_size avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size self.coordinator.print_on_master( - f'num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, ' - f'avg_throughput: {avg_throughput}') + f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, " + f"avg_throughput: {avg_throughput}" + ) self.coordinator.print_on_master( - f'Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}') + f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}" + ) diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py index 0eeac4035401..6cc73b6265a4 100644 --- a/examples/language/llama2/pretrain.py +++ b/examples/language/llama2/pretrain.py @@ -29,21 +29,22 @@ from colossalai.utils import get_current_device MODEL_CONFIGS = { - '7b': - LlamaConfig(max_position_embeddings=4096), - '13b': - LlamaConfig(hidden_size=5120, - intermediate_size=13824, - num_hidden_layers=40, - num_attention_heads=40, - max_position_embeddings=4096), - '70b': - LlamaConfig(hidden_size=8192, - intermediate_size=28672, - num_hidden_layers=80, - num_attention_heads=64, - max_position_embeddings=4096, - num_key_value_heads=8), + "7b": LlamaConfig(max_position_embeddings=4096), + "13b": LlamaConfig( + hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=40, + num_attention_heads=40, + max_position_embeddings=4096, + ), + "70b": LlamaConfig( + hidden_size=8192, + intermediate_size=28672, + num_hidden_layers=80, + num_attention_heads=64, + max_position_embeddings=4096, + num_key_value_heads=8, + ), } @@ -56,20 +57,20 @@ def format_numel_str(numel: int) -> str: M = 1024**2 K = 1024 if numel >= B: - return f'{numel / B:.2f} B' + return f"{numel / B:.2f} B" elif numel >= M: - return f'{numel / M:.2f} M' + return f"{numel / M:.2f} M" elif numel >= K: - return f'{numel / K:.2f} K' + return f"{numel / K:.2f} K" else: - return f'{numel}' + return f"{numel}" def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): - texts = [sample['text'] for sample in batch] - data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length) + texts = [sample["text"] for sample in batch] + data = tokenizer(texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length) data = {k: v.cuda() for k, v in data.items()} - data['labels'] = data['input_ids'].clone() + data["labels"] = data["input_ids"].clone() return data @@ -79,30 +80,40 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: return tensor -def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int, - batch_size: int, coordinator: DistCoordinator, save_dir: str): - save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}') - os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True) - - booster.save_model(model, os.path.join(save_dir, 'model'), shard=True) - booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True) - booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler')) +def save( + booster: Booster, + model: nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, + save_dir: str, +): + save_dir = os.path.join(save_dir, f"epoch{epoch}-step{step}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "model"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) running_states = { - 'epoch': epoch, - 'step': step, - 'sample_start_index': step * batch_size, + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, } if coordinator.is_master(): - save_json(running_states, os.path.join(save_dir, 'running_states.json')) + save_json(running_states, os.path.join(save_dir, "running_states.json")) -def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, - load_dir: str) -> Tuple[int, int, int]: - booster.load_model(model, os.path.join(load_dir, 'model')) - booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer')) - booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler')) - running_states = load_json(os.path.join(load_dir, 'running_states.json')) - return running_states['epoch'], running_states['step'], running_states['sample_start_index'] +def load( + booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str +) -> Tuple[int, int, int]: + booster.load_model(model, os.path.join(load_dir, "model")) + booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) + running_states = load_json(os.path.join(load_dir, "running_states.json")) + return running_states["epoch"], running_states["step"], running_states["sample_start_index"] def _criterion(outputs, inputs): @@ -114,31 +125,31 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration') - parser.add_argument('-p', - '--plugin', - choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'], - default='gemini', - help='Choose which plugin to use') - parser.add_argument('-d', - '--dataset', - type=str, - default='togethercomputer/RedPajama-Data-1T-Sample', - help='Data set path') - parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs') - parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size') - parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate') - parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay') - parser.add_argument('-s', '--warmup_steps', type=int, default=2000, help='Warmup steps') - parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing') - parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length') - parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision') - parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval') - parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory') - parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint') - parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping') - parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory') - parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention') + parser.add_argument("-c", "--config", type=str, default="7b", help="Model configuration") + parser.add_argument( + "-p", + "--plugin", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "hybrid_parallel"], + default="gemini", + help="Choose which plugin to use", + ) + parser.add_argument( + "-d", "--dataset", type=str, default="togethercomputer/RedPajama-Data-1T-Sample", help="Data set path" + ) + parser.add_argument("-e", "--num_epochs", type=int, default=1, help="Number of epochs") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="Local batch size") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("-w", "--weigth_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("-s", "--warmup_steps", type=int, default=2000, help="Warmup steps") + parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing") + parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length") + parser.add_argument("-x", "--mixed_precision", default="fp16", choices=["fp16", "bf16"], help="Mixed precision") + parser.add_argument("-i", "--save_interval", type=int, default=1000, help="Save interval") + parser.add_argument("-o", "--save_dir", type=str, default="checkpoint", help="Checkpoint directory") + parser.add_argument("-f", "--load", type=str, default=None, help="Load checkpoint") + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("-t", "--tensorboard_dir", type=str, default="tb_logs", help="Tensorboard directory") + parser.add_argument("-a", "--flash_attention", action="store_true", help="Use Flash Attention") args = parser.parse_args() # ============================== @@ -150,36 +161,34 @@ def main(): # ============================== # Initialize Booster # ============================== - if args.plugin == 'gemini': + if args.plugin == "gemini": plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip) - elif args.plugin == 'gemini_auto': - plugin = GeminiPlugin(precision=args.mixed_precision, - placement_policy='auto', - initial_scale=2**16, - max_norm=args.grad_clip) - elif args.plugin == 'zero2': - plugin = LowLevelZeroPlugin(stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - max_norm=args.grad_clip) - elif args.plugin == 'zero2_cpu': - plugin = LowLevelZeroPlugin(stage=2, - precision=args.mixed_precision, - initial_scale=2**16, - cpu_offload=True, - max_norm=args.grad_clip) - elif args.plugin == 'hybrid_parallel': + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, placement_policy="auto", initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, precision=args.mixed_precision, initial_scale=2**16, cpu_offload=True, max_norm=args.grad_clip + ) + elif args.plugin == "hybrid_parallel": # modify the param accordingly, default configuration is for llama2-7b - plugin = HybridParallelPlugin(tp_size=4, - pp_size=2, - num_microbatches=None, - microbatch_size=1, - enable_jit_fused=False, - zero_stage=0, - precision='fp32', - initial_scale=1) + plugin = HybridParallelPlugin( + tp_size=4, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_jit_fused=False, + zero_stage=0, + precision="fp32", + initial_scale=1, + ) else: - raise ValueError(f'Unknown plugin {args.plugin}') + raise ValueError(f"Unknown plugin {args.plugin}") booster = Booster(plugin=plugin) @@ -197,27 +206,28 @@ def main(): # ============================== # Initialize Tokenizer, Dataset and Dataloader # ============================== - tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer') + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257 tokenizer.pad_token = tokenizer.unk_token dataset = load_dataset(args.dataset) - train_ds = dataset['train'] - dataloader = prepare_dataloader(train_ds, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=partial(tokenize_batch_for_pretrain, - tokenizer=tokenizer, - max_length=args.max_length)) + train_ds = dataset["train"] + dataloader = prepare_dataloader( + train_ds, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=partial(tokenize_batch_for_pretrain, tokenizer=tokenizer, max_length=args.max_length), + ) # ============================== # Initialize Model, Optimizer and LR Scheduler # ============================== config = MODEL_CONFIGS[args.config] # use lazy init when using GeminiPlugin - init_ctx = LazyInitContext( - default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + init_ctx = ( + LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() + ) with init_ctx: model = LlamaForCausalLM(config) @@ -225,37 +235,36 @@ def main(): if args.grad_checkpoint: model.gradient_checkpointing_enable() if args.flash_attention: - assert SUPPORT_XFORMERS, 'Use flash attention while xfomers is not installed' + assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed" replace_xformers(model) model_numel = get_model_numel(model) - coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}') + coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay) - lr_scheduler = CosineAnnealingWarmupLR(optimizer, - total_steps=args.num_epochs * len(dataloader), - warmup_steps=args.warmup_steps, - eta_min=0.1 * args.lr) - default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16 + lr_scheduler = CosineAnnealingWarmupLR( + optimizer, total_steps=args.num_epochs * len(dataloader), warmup_steps=args.warmup_steps, eta_min=0.1 * args.lr + ) + default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 torch.set_default_dtype(default_dtype) - model, optimizer, _, dataloader, lr_scheduler = booster.boost(model, - optimizer, - dataloader=dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler + ) torch.set_default_dtype(torch.float) - coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master( - f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB') + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" + ) # load checkpoint if specified start_epoch = 0 start_step = 0 sampler_start_idx = 0 if args.load is not None: - coordinator.print_on_master('Loading checkpoint') + coordinator.print_on_master("Loading checkpoint") start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load) - coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}') + coordinator.print_on_master(f"Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}") num_steps_per_epoch = len(dataloader) @@ -266,19 +275,18 @@ def main(): step_nums = num_steps_per_epoch - start_step dataloader_iter = iter(dataloader) - with tqdm(range(step_nums), - desc=f'Epoch {epoch}', - disable=not print_flag, - total=num_steps_per_epoch, - initial=start_step) as pbar: + with tqdm( + range(step_nums), + desc=f"Epoch {epoch}", + disable=not print_flag, + total=num_steps_per_epoch, + initial=start_step, + ) as pbar: for step in pbar: if use_pipeline: - outputs = booster.execute_pipeline(dataloader_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) loss = outputs["loss"] else: batch = next(dataloader_iter) @@ -293,20 +301,29 @@ def main(): if not use_pipeline: all_reduce_mean(loss) if print_flag: - pbar.set_postfix({'loss': loss.item()}) - writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step) + pbar.set_postfix({"loss": loss.item()}) + writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step) if args.save_interval > 0 and (step + 1) % args.save_interval == 0: - coordinator.print_on_master(f'Saving checkpoint') - save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator, - args.save_dir) - coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}') + coordinator.print_on_master(f"Saving checkpoint") + save( + booster, + model, + optimizer, + lr_scheduler, + epoch, + step + 1, + args.batch_size, + coordinator, + args.save_dir, + ) + coordinator.print_on_master(f"Saved checkpoint at epoch {epoch} step {step + 1}") # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(0) start_step = 0 - coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB') + coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/language/opt/args.py b/examples/language/opt/args.py index 77fa12bc8a0c..1ec19094e19e 100644 --- a/examples/language/opt/args.py +++ b/examples/language/opt/args.py @@ -2,36 +2,35 @@ def parse_demo_args(): - parser = get_default_parser() - parser.add_argument("--model_name_or_path", - type=str, - default="facebook/opt-350m", - help="Path to pretrained model or model identifier from huggingface.co/models.") - parser.add_argument("--output_path", - type=str, - default="./output_model.bin", - help="The path of your saved model after finetuning.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="facebook/opt-350m", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_path", type=str, default="./output_model.bin", help="The path of your saved model after finetuning." + ) parser.add_argument( "--plugin", type=str, default="gemini", - help= - "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'." + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.", ) parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.") - parser.add_argument("--batch_size", - type=int, - default=32, - help="Batch size (per dp group) for the training dataloader.") - parser.add_argument("--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.") - parser.add_argument("--warmup_ratio", - type=float, - default=0.1, - help="Ratio of warmup steps against total training steps.") + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--warmup_ratio", type=float, default=0.1, help="Ratio of warmup steps against total training steps." + ) parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") @@ -40,25 +39,28 @@ def parse_demo_args(): def parse_benchmark_args(): - parser = get_default_parser() - parser.add_argument("--model_name_or_path", - type=str, - default="facebook/opt-125m", - help="Path to pretrained model or model identifier from huggingface.co/models.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="facebook/opt-125m", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) parser.add_argument( "--plugin", type=str, default="gemini", - help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'.") - parser.add_argument("--batch_size", - type=int, - default=32, - help="Batch size (per dp group) for the training dataloader.") - parser.add_argument("--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use.") + help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'.", + ) + parser.add_argument( + "--batch_size", type=int, default=32, help="Batch size (per dp group) for the training dataloader." + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") diff --git a/examples/language/opt/data.py b/examples/language/opt/data.py index 6cfffb5fc95b..9b9cc59518ab 100644 --- a/examples/language/opt/data.py +++ b/examples/language/opt/data.py @@ -1,37 +1,38 @@ import torch -from torch.utils.data import Dataset from datasets import load_dataset +from torch.utils.data import Dataset class NetflixDataset(Dataset): - def __init__(self, tokenizer): - super().__init__() self.tokenizer = tokenizer self.input_ids = [] self.attn_masks = [] self.labels = [] - self.txt_list = netflix_descriptions = load_dataset("hugginglearners/netflix-shows", split="train")['description'] + self.txt_list = netflix_descriptions = load_dataset("hugginglearners/netflix-shows", split="train")[ + "description" + ] self.max_length = max([len(self.tokenizer.encode(description)) for description in netflix_descriptions]) for txt in self.txt_list: - encodings_dict = self.tokenizer('' + txt + '', - truncation=True, - max_length=self.max_length, - padding="max_length") - self.input_ids.append(torch.tensor(encodings_dict['input_ids'])) - self.attn_masks.append(torch.tensor(encodings_dict['attention_mask'])) + encodings_dict = self.tokenizer( + "" + txt + "", truncation=True, max_length=self.max_length, padding="max_length" + ) + self.input_ids.append(torch.tensor(encodings_dict["input_ids"])) + self.attn_masks.append(torch.tensor(encodings_dict["attention_mask"])) def __len__(self): return len(self.input_ids) def __getitem__(self, idx): return self.input_ids[idx], self.attn_masks[idx] - + def netflix_collator(data): - return {'input_ids': torch.stack([x[0] for x in data]), - 'attention_mask': torch.stack([x[1] for x in data]), - 'labels': torch.stack([x[0] for x in data])} + return { + "input_ids": torch.stack([x[0] for x in data]), + "attention_mask": torch.stack([x[1] for x in data]), + "labels": torch.stack([x[0] for x in data]), + } diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index 90ed10ec7cca..d16c9fdf99ad 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -35,6 +35,7 @@ def get_data(batch_size, seq_len, vocab_size): def colo_memory_cap(size_in_GB): from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) @@ -42,7 +43,6 @@ def colo_memory_cap(size_in_GB): def main(): - args = parse_benchmark_args() # Launch ColossalAI @@ -72,13 +72,13 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) @@ -101,11 +101,10 @@ def main(): start_time = time.time() for _ in range(args.max_train_steps): - input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE) optimizer.zero_grad() outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False) - loss = outputs['loss'] + loss = outputs["loss"] booster.backward(loss, optimizer) optimizer.step() @@ -123,7 +122,8 @@ def main(): f"plugin: {args.plugin}, " f"throughput: {throughput}, " f"maximum memory usage per gpu: {max_mem}.", - ranks=[0]) + ranks=[0], + ) if __name__ == "__main__": diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py index 7d6bdfb9f31c..fddbc1b408e7 100644 --- a/examples/language/opt/opt_train_demo.py +++ b/examples/language/opt/opt_train_demo.py @@ -1,5 +1,3 @@ -import time - import datasets import torch import transformers @@ -12,7 +10,6 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin -from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam @@ -29,7 +26,6 @@ def move_to_cuda(batch, device): def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, booster, coordinator): - torch.cuda.synchronize() use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 @@ -39,22 +35,19 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b model.train() optimizer.zero_grad() dataloader = iter(dataloader) - with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}]', - disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar: - + with tqdm( + range(total_step), desc=f"Epoch [{epoch + 1}]", disable=not (coordinator.is_master() or is_pp_last_stage) + ) as pbar: # Forward pass for _ in pbar: if use_pipeline: - outputs = booster.execute_pipeline(dataloader, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) + outputs = booster.execute_pipeline( + dataloader, model, _criterion, optimizer, return_loss=True, return_outputs=True + ) # Backward and optimize if is_pp_last_stage: - loss = outputs['loss'] - pbar.set_postfix({'loss': loss.item()}) + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) else: data = next(dataloader) data = move_to_cuda(data) @@ -62,7 +55,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b loss = _criterion(outputs, None) # Backward booster.backward(loss, optimizer) - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) optimizer.step() optimizer.zero_grad() @@ -70,7 +63,6 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b def main(): - args = parse_demo_args() # Launch ColossalAI @@ -98,34 +90,34 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': + elif args.plugin == "hybrid_parallel": # modify the param accordingly for finetuning test cases - plugin = HybridParallelPlugin(tp_size=2, - pp_size=2, - num_microbatches=2, - enable_all_optimization=True, - zero_stage=0, - precision='fp16', - initial_scale=1) + plugin = HybridParallelPlugin( + tp_size=2, + pp_size=2, + num_microbatches=2, + enable_all_optimization=True, + zero_stage=0, + precision="fp16", + initial_scale=1, + ) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Prepare tokenizer and dataloader tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) dataset = NetflixDataset(tokenizer) - dataloader = plugin.prepare_dataloader(dataset, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=netflix_collator) + dataloader = plugin.prepare_dataloader( + dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=netflix_collator + ) # Set optimizer optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) @@ -133,9 +125,9 @@ def main(): # Set lr scheduler total_steps = len(dataloader) * args.num_epoch num_warmup_steps = int(args.warmup_ratio * total_steps) - lr_scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=len(dataloader) * args.num_epoch) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=len(dataloader) * args.num_epoch + ) # Define criterion def _criterion(outputs, inputs): @@ -145,11 +137,9 @@ def _criterion(outputs, inputs): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _criterion, dataloader, lr_scheduler = booster.boost(model=model, - optimizer=optimizer, - dataloader=dataloader, - criterion=_criterion, - lr_scheduler=lr_scheduler) + model, optimizer, _criterion, dataloader, lr_scheduler = booster.boost( + model=model, optimizer=optimizer, dataloader=dataloader, criterion=_criterion, lr_scheduler=lr_scheduler + ) # Start finetuning logger.info(f"Start finetuning", ranks=[0]) diff --git a/examples/language/opt/run_benchmark.sh b/examples/language/opt/run_benchmark.sh index 76c5e8601989..b94ee61f277c 100644 --- a/examples/language/opt/run_benchmark.sh +++ b/examples/language/opt/run_benchmark.sh @@ -24,7 +24,7 @@ torchrun \ --mem_cap ${MEMCAP} \ --plugin ${PLUGIN} \ --batch_size ${BS} - + done done done diff --git a/examples/language/palm/palm_pytorch/autoregressive_wrapper.py b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py index dc4f3d856fec..17251c2f4fb3 100644 --- a/examples/language/palm/palm_pytorch/autoregressive_wrapper.py +++ b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py @@ -11,7 +11,6 @@ def exists(val): def eval_decorator(fn): - def inner(model, *args, **kwargs): was_training = model.training model.eval() @@ -34,7 +33,6 @@ def top_k(logits, thres=0.9): class AutoregressiveWrapper(nn.Module): - def __init__(self, net, max_seq_len=2048, pad_value=0): super().__init__() self.max_seq_len = max_seq_len diff --git a/examples/language/palm/palm_pytorch/palm_pytorch.py b/examples/language/palm/palm_pytorch/palm_pytorch.py index c37974711e11..6be966d67790 100644 --- a/examples/language/palm/palm_pytorch/palm_pytorch.py +++ b/examples/language/palm/palm_pytorch/palm_pytorch.py @@ -1,14 +1,13 @@ import torch import torch.nn.functional as F from einops import rearrange -from torch import einsum, matmul, nn +from torch import matmul, nn # normalization # they use layernorm without bias, something that pytorch does not offer class LayerNorm(nn.Module): - def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps @@ -24,7 +23,6 @@ def forward(self, x): class ParallelResidual(nn.Module): - def __init__(self, *fns): super().__init__() self.fns = nn.ModuleList(fns) @@ -38,16 +36,15 @@ def forward(self, x): class RotaryEmbedding(nn.Module): - def __init__(self, dim): super().__init__() - inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) def forward(self, max_seq_len, *, device): seq = torch.arange(max_seq_len, device=device) - #freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) - #freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) + # freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) + # freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq) freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j)) return torch.cat((freqs, freqs), dim=-1) @@ -69,7 +66,6 @@ def apply_rotary_pos_emb(pos, t): class SwiGLU(nn.Module): - def forward(self, x): x, gate = x.chunk(2, dim=-1) return F.silu(gate) * x @@ -87,7 +83,6 @@ def FeedForward(dim, mult=4): # attention class Attention(nn.Module): - def __init__(self, dim, dim_head=64, heads=8): super().__init__() inner_dim = dim_head * heads @@ -160,7 +155,7 @@ def forward(self, x): # similarity - #sim = einsum("b h i d, b j d -> b h i j", q, k) + # sim = einsum("b h i d, b j d -> b h i j", q, k) sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2)) sim = sim.reshape(b, h, i, j) @@ -178,7 +173,7 @@ def forward(self, x): # aggregate values - #out = einsum("b h i j, b j d -> b h i d", attn, v) + # out = einsum("b h i j, b j d -> b h i d", attn, v) out = matmul(attn.reshape(b_, h_ * i_, j_), v) out = out.reshape(b_, h_, i_, d_) @@ -193,12 +188,17 @@ def forward(self, x): def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4): net = nn.Sequential( - nn.Embedding(num_tokens, dim), *[ + nn.Embedding(num_tokens, dim), + *[ ParallelResidual( Attention(dim=dim, dim_head=dim_head, heads=heads), FeedForward(dim=dim, mult=ff_mult), - ) for _ in range(depth) - ], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False)) + ) + for _ in range(depth) + ], + LayerNorm(dim), + nn.Linear(dim, num_tokens, bias=False), + ) # they used embedding weight tied projection out to logits, not common, but works net[-1].weight = net[0].weight diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 526f791403ff..e7af88c55121 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -37,7 +37,7 @@ def parse_args(): parser.add_argument( "--distplan", type=str, - default='colossalai', + default="colossalai", help="The distributed plan [colossalai, pytorch].", ) parser.add_argument( @@ -46,12 +46,14 @@ def parse_args(): default=1.0, help="Fraction of optimizer states to be offloaded. This is only used for gemini.", ) - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], - help="plugin to use") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) parser.add_argument( "--batch_size", type=int, @@ -122,7 +124,6 @@ def generate_dataset(dummy_data: bool = False): class TextSamplerDataset(Dataset): - def __init__(self, data, seq_len): super().__init__() self.data = data @@ -130,7 +131,7 @@ def __init__(self, data, seq_len): def __getitem__(self, index): rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) - full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long() + full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long() return full_seq.cuda() def __len__(self): @@ -146,18 +147,18 @@ def __len__(self): # instantiate GPT-like decoder model booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': + elif args.plugin == "gemini": plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"plugin: {plugin}") booster = Booster(plugin=plugin, **booster_kwargs) - ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext() + ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext() with ctx: model = PaLM(num_tokens=50304, dim=4096, depth=64) @@ -182,7 +183,6 @@ def __len__(self): model.train() tflops_list = [] for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"): - if args.distplan == "colossalai": optimizer.zero_grad() start = time() @@ -231,12 +231,12 @@ def __len__(self): # loss = model(next(val_loader)) # print(f"validation loss: {loss.item()}") - # if i % GENERATE_EVERY == 0: - # model.eval() - # inp = random.choice(val_dataset)[:-1] - # prime = decode_tokens(inp) - # print(f"%s \n\n %s", (prime, "*" * 100)) +# if i % GENERATE_EVERY == 0: +# model.eval() +# inp = random.choice(val_dataset)[:-1] +# prime = decode_tokens(inp) +# print(f"%s \n\n %s", (prime, "*" * 100)) - # sample = model.generate(inp[None, ...], GENERATE_LENGTH) - # output_str = decode_tokens(sample[0]) - # print(output_str) +# sample = model.generate(inp[None, ...], GENERATE_LENGTH) +# output_str = decode_tokens(sample[0]) +# print(output_str) diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md index 7b5668612818..a54c7b4da3bd 100644 --- a/examples/tutorial/README.md +++ b/examples/tutorial/README.md @@ -4,7 +4,7 @@ ## Introduction -Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), +Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [NeurIPS](https://nips.cc/), [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), [NVIDIA GTC](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-S51482/) ,etc. diff --git a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py index 5a68aae18041..29101ce08434 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py @@ -20,20 +20,22 @@ def _benchmark(rank, world_size, port): only result in minor performance drop. So at last we might be able to find better training batch size for our model (combine with large batch training optimizer such as LAMB). """ - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = tm.resnet152() gm = symbolic_trace(model) raw_graph = deepcopy(gm.graph) peak_mems, through_puts, batch_sizes = [], [], [512, 1024, 2048] for batch_size in batch_sizes: batch_size = int(batch_size) - gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta')) + gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device="meta")) solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95) gm.graph = solver.solve() - peak_mem, step_time = bench(gm, - torch.nn.CrossEntropyLoss(), - partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)), - num_steps=5) + peak_mem, step_time = bench( + gm, + torch.nn.CrossEntropyLoss(), + partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)), + num_steps=5, + ) peak_mems.append(peak_mem) through_puts.append(batch_size / step_time * 1.0e3) gm.graph = deepcopy(raw_graph) @@ -41,7 +43,7 @@ def _benchmark(rank, world_size, port): # print results print("===============benchmark summary================") for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts): - print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s') + print(f"batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s") def auto_activation_checkpoint_batchsize_benchmark(): diff --git a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py index aa5c47294a82..cd03a917912e 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py @@ -1,4 +1,3 @@ -import time from argparse import ArgumentParser from functools import partial @@ -8,7 +7,6 @@ from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium import colossalai -from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace, symbolic_trace from colossalai.testing import spawn @@ -19,37 +17,33 @@ def _benchmark(rank, world_size, port, args): The benchmark will sample in a range of memory budget for each model and output the benchmark summary and data visualization of peak memory vs. budget memory and relative step time vs. peak memory. """ - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - if args.model == 'resnet50': + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + if args.model == "resnet50": model = tm.resnet50() data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224)) gm = symbolic_trace(model) - gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta')) + gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device="meta")) loss = torch.nn.CrossEntropyLoss() else: model = gpt2_medium() data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257) - data, mask = data_gen(device='meta')[0] - gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + data, mask = data_gen(device="meta")[0] + gm = symbolic_trace(model, meta_args={"input_ids": data, "attention_mask": mask}) gm = metainfo_trace(gm, data, mask) loss = GPTLMLoss() - free_memory = 11000 * 1024**2 if args.model == 'resnet50' else 56000 * 1024**2 - start_factor = 4 if args.model == 'resnet50' else 10 + free_memory = 11000 * 1024**2 if args.model == "resnet50" else 56000 * 1024**2 + start_factor = 4 if args.model == "resnet50" else 10 # trace and benchmark - budgets, peak_hist, step_hist = bench_rotor(gm, - loss, - data_gen, - num_steps=5, - sample_points=15, - free_memory=free_memory, - start_factor=start_factor) + budgets, peak_hist, step_hist = bench_rotor( + gm, loss, data_gen, num_steps=5, sample_points=15, free_memory=free_memory, start_factor=start_factor + ) # print summary print("==============benchmark summary==============") for budget, peak, step in zip(budgets, peak_hist, step_hist): - print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS') + print(f"memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS") # plot valid results fig, axs = plt.subplots(1, 2, figsize=(16, 8)) @@ -57,14 +51,14 @@ def _benchmark(rank, world_size, port, args): # plot peak memory vs. budget memory axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:]) - axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--') + axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle="--") axs[0].set_xlabel("Budget Memory (MB)") axs[0].set_ylabel("Peak Memory (MB)") axs[0].set_title("Peak Memory vs. Budget Memory") # plot relative step time vs. budget memory axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]]) - axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--') + axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle="--") axs[1].set_xlabel("Peak Memory (MB)") axs[1].set_ylabel("Relative Step Time") axs[1].set_title("Step Time vs. Peak Memory") @@ -81,7 +75,7 @@ def auto_activation_checkpoint_benchmark(args): if __name__ == "__main__": parser = ArgumentParser("Auto Activation Checkpoint Solver Benchmark") - parser.add_argument("--model", type=str, default='gpt2', choices=['gpt2', 'resnet50']) + parser.add_argument("--model", type=str, default="gpt2", choices=["gpt2", "resnet50"]) args = parser.parse_args() auto_activation_checkpoint_benchmark(args) diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py index 33aa5990f7c1..3c5b786b561a 100644 --- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py +++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py @@ -17,14 +17,14 @@ def synthesize_data(): def main(): - colossalai.launch_from_torch(config='./config.py') + colossalai.launch_from_torch(config="./config.py") logger = get_dist_logger() # trace the model with meta data model = resnet50(num_classes=10).cuda() - input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')} + input_sample = {"x": torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to("meta")} device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True) model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True) @@ -88,8 +88,9 @@ def main(): logger.info( f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", - ranks=[0]) + ranks=[0], + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/auto_parallel/bench_utils.py b/examples/tutorial/auto_parallel/bench_utils.py index 69859f885ae6..96cfd49c6787 100644 --- a/examples/tutorial/auto_parallel/bench_utils.py +++ b/examples/tutorial/auto_parallel/bench_utils.py @@ -1,22 +1,19 @@ import time from copy import deepcopy -from functools import partial from typing import Callable, Tuple import numpy as np import torch import torch.nn as nn -import torchvision.models as tm from transformers import GPT2Config, GPT2LMHeadModel from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace -def bench(gm: torch.fx.GraphModule, - criterion: torch.nn.Module, - data_gen: Callable, - num_steps: int = 5) -> Tuple[int, int]: +def bench( + gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5 +) -> Tuple[int, int]: """Benchmarking a given graph module Args: gm (torch.fx.GraphModule): The graph module to benchmark. @@ -28,7 +25,7 @@ def bench(gm: torch.fx.GraphModule, """ gm.train() gm.cuda() - step_time = float('inf') + step_time = float("inf") torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -58,13 +55,15 @@ def bench(gm: torch.fx.GraphModule, return peak_mem, step_time * 1.0e3 -def bench_rotor(gm: torch.fx.GraphModule, - criterion: torch.nn.Module, - data_gen: Callable, - num_steps: int = 5, - sample_points: int = 20, - free_memory: int = torch.cuda.mem_get_info()[0], - start_factor: int = 4) -> Tuple[np.array, list, list]: +def bench_rotor( + gm: torch.fx.GraphModule, + criterion: torch.nn.Module, + data_gen: Callable, + num_steps: int = 5, + sample_points: int = 20, + free_memory: int = torch.cuda.mem_get_info()[0], + start_factor: int = 4, +) -> Tuple[np.array, list, list]: """Auto Checkpoint Rotor Algorithm benchmarking Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data. Args: @@ -88,7 +87,7 @@ def bench_rotor(gm: torch.fx.GraphModule, gm.graph = solver.solve(verbose=False) peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps) except: - peak_memory, step_time = budget / 1024**2, float('inf') + peak_memory, step_time = budget / 1024**2, float("inf") peak_hist.append(peak_memory) step_hist.append(step_time) gm.graph = deepcopy(raw_graph) @@ -100,22 +99,27 @@ class GPTLMModel(nn.Module): GPT Model """ - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) if checkpoint: self.model.gradient_checkpointing_enable() @@ -152,7 +156,7 @@ def gpt2_6b(checkpoint=False): return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) -def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'): +def data_gen_gpt2(batch_size, seq_len, vocab_size, device="cuda:0"): """ Generate random data for gpt2 benchmarking """ @@ -161,7 +165,7 @@ def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'): return (input_ids, attention_mask), attention_mask -def data_gen_resnet(batch_size, shape, device='cuda:0'): +def data_gen_resnet(batch_size, shape, device="cuda:0"): """ Generate random data for resnet benchmarking """ diff --git a/examples/tutorial/auto_parallel/setup.py b/examples/tutorial/auto_parallel/setup.py index 6e6cff32ed23..94d5ec0c0e9e 100644 --- a/examples/tutorial/auto_parallel/setup.py +++ b/examples/tutorial/auto_parallel/setup.py @@ -1,13 +1,13 @@ from setuptools import find_packages, setup setup( - name='auto_parallel', - version='0.0.1', - description='', + name="auto_parallel", + version="0.0.1", + description="", packages=find_packages(), install_requires=[ - 'torch', - 'numpy', - 'tqdm', + "torch", + "numpy", + "tqdm", ], ) diff --git a/examples/tutorial/download_cifar10.py b/examples/tutorial/download_cifar10.py index 5c6b6988ade5..78ea3d1e062e 100644 --- a/examples/tutorial/download_cifar10.py +++ b/examples/tutorial/download_cifar10.py @@ -5,9 +5,9 @@ def main(): dir_path = os.path.dirname(os.path.realpath(__file__)) - data_root = os.path.join(dir_path, 'data') + data_root = os.path.join(dir_path, "data") dataset = CIFAR10(root=data_root, download=True) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/hybrid_parallel/config.py b/examples/tutorial/hybrid_parallel/config.py index 287f62aa7a90..15f9d0bc75ee 100644 --- a/examples/tutorial/hybrid_parallel/config.py +++ b/examples/tutorial/hybrid_parallel/config.py @@ -18,11 +18,11 @@ MLP_RATIO = 2 NUM_CLASSES = 10 CHECKPOINT = False -SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token +SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token # parallel setting TENSOR_PARALLEL_SIZE = 2 -TENSOR_PARALLEL_MODE = '1d' +TENSOR_PARALLEL_MODE = "1d" parallel = dict( pipeline=2, @@ -33,4 +33,4 @@ clip_grad_norm = 1.0 # pipeline config -NUM_MICRO_BATCHES = parallel['pipeline'] +NUM_MICRO_BATCHES = parallel["pipeline"] diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py index 21a568168e33..95f1bf8ee17c 100644 --- a/examples/tutorial/hybrid_parallel/train.py +++ b/examples/tutorial/hybrid_parallel/train.py @@ -14,8 +14,7 @@ from colossalai.utils import is_using_pp -class DummyDataloader(): - +class DummyDataloader: def __init__(self, length, batch_size): self.length = length self.batch_size = batch_size @@ -50,7 +49,7 @@ def main(): logger = get_dist_logger() logger.info("initialized distributed environment", ranks=[0]) - if hasattr(gpc.config, 'LOG_PATH'): + if hasattr(gpc.config, "LOG_PATH"): if gpc.get_global_rank() == 0: log_path = gpc.config.LOG_PATH if not os.path.exists(log_path): @@ -60,15 +59,17 @@ def main(): use_pipeline = is_using_pp() # create model - model_kwargs = dict(img_size=gpc.config.IMG_SIZE, - patch_size=gpc.config.PATCH_SIZE, - hidden_size=gpc.config.HIDDEN_SIZE, - depth=gpc.config.DEPTH, - num_heads=gpc.config.NUM_HEADS, - mlp_ratio=gpc.config.MLP_RATIO, - num_classes=10, - init_method='jax', - checkpoint=gpc.config.CHECKPOINT) + model_kwargs = dict( + img_size=gpc.config.IMG_SIZE, + patch_size=gpc.config.PATCH_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + depth=gpc.config.DEPTH, + num_heads=gpc.config.NUM_HEADS, + mlp_ratio=gpc.config.MLP_RATIO, + num_classes=10, + init_method="jax", + checkpoint=gpc.config.CHECKPOINT, + ) if use_pipeline: pipelinable = PipelinableContext() @@ -102,16 +103,18 @@ def main(): optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=gpc.config.NUM_EPOCHS, - warmup_steps=gpc.config.WARMUP_EPOCHS) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS + ) # initialize - engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader) + engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + ) logger.info("Engine is built", ranks=[0]) @@ -121,7 +124,7 @@ def main(): data_iter = iter(train_dataloader) if gpc.get_global_rank() == 0: - description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS) progress = tqdm(range(len(train_dataloader)), desc=description) else: progress = range(len(train_dataloader)) @@ -133,5 +136,5 @@ def main(): gpc.destroy() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py index 6ebd8d68083d..dd114b5af86d 100644 --- a/examples/tutorial/large_batch_optimizer/train.py +++ b/examples/tutorial/large_batch_optimizer/train.py @@ -10,8 +10,7 @@ from colossalai.nn.optimizer import Lamb, Lars -class DummyDataloader(): - +class DummyDataloader: def __init__(self, length, batch_size): self.length = length self.batch_size = batch_size @@ -39,10 +38,9 @@ def __len__(self): def main(): # initialize distributed setting parser = colossalai.get_default_parser() - parser.add_argument('--optimizer', - choices=['lars', 'lamb'], - help="Choose your large-batch optimizer", - required=True) + parser.add_argument( + "--optimizer", choices=["lars", "lamb"], help="Choose your large-batch optimizer", required=True + ) args = parser.parse_args() # launch from torch @@ -70,16 +68,18 @@ def main(): optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY) # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, - total_steps=gpc.config.NUM_EPOCHS, - warmup_steps=gpc.config.WARMUP_EPOCHS) + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=gpc.config.WARMUP_EPOCHS + ) # initialize - engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader, - test_dataloader=test_dataloader) + engine, train_dataloader, test_dataloader, _ = colossalai.initialize( + model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + ) logger.info("Engine is built", ranks=[0]) @@ -89,7 +89,7 @@ def main(): data_iter = iter(train_dataloader) if gpc.get_global_rank() == 0: - description = 'Epoch {} / {}'.format(epoch, gpc.config.NUM_EPOCHS) + description = "Epoch {} / {}".format(epoch, gpc.config.NUM_EPOCHS) progress = tqdm(range(len(train_dataloader)), desc=description) else: progress = range(len(train_dataloader)) @@ -100,5 +100,5 @@ def main(): lr_scheduler.step() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/new_api/cifar_resnet/eval.py b/examples/tutorial/new_api/cifar_resnet/eval.py index 657708ec3ff2..526e41a2850f 100644 --- a/examples/tutorial/new_api/cifar_resnet/eval.py +++ b/examples/tutorial/new_api/cifar_resnet/eval.py @@ -1,7 +1,6 @@ import argparse import torch -import torch.nn as nn import torchvision import torchvision.transforms as transforms @@ -9,15 +8,15 @@ # Parse Arguments # ============================== parser = argparse.ArgumentParser() -parser.add_argument('-e', '--epoch', type=int, default=80, help="resume from the epoch's checkpoint") -parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") +parser.add_argument("-e", "--epoch", type=int, default=80, help="resume from the epoch's checkpoint") +parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") args = parser.parse_args() # ============================== # Prepare Test Dataset # ============================== # CIFAR-10 dataset -test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, transform=transforms.ToTensor()) +test_dataset = torchvision.datasets.CIFAR10(root="./data/", train=False, transform=transforms.ToTensor()) # Data loader test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False) @@ -26,7 +25,7 @@ # Load Model # ============================== model = torchvision.models.resnet18(num_classes=10).cuda() -state_dict = torch.load(f'{args.checkpoint}/model_{args.epoch}.pth') +state_dict = torch.load(f"{args.checkpoint}/model_{args.epoch}.pth") model.load_state_dict(state_dict) # ============================== @@ -45,4 +44,4 @@ total += labels.size(0) correct += (predicted == labels).sum().item() - print('Accuracy of the model on the test images: {} %'.format(100 * correct / total)) + print("Accuracy of the model on the test images: {} %".format(100 * correct / total)) diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py index fe0dabf08377..6ae2d8b0412f 100644 --- a/examples/tutorial/new_api/cifar_resnet/train.py +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -30,23 +30,19 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): # transform transform_train = transforms.Compose( - [transforms.Pad(4), - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32), - transforms.ToTensor()]) + [transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()] + ) transform_test = transforms.ToTensor() # CIFAR-10 dataset - data_path = os.environ.get('DATA', './data') + data_path = os.environ.get("DATA", "./data") with coordinator.priority_execution(): - train_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=True, - transform=transform_train, - download=True) - test_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=False, - transform=transform_test, - download=True) + train_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=True, transform=transform_train, download=True + ) + test_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=False, transform=transform_test, download=True + ) # Data loader train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) @@ -70,14 +66,21 @@ def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoo dist.all_reduce(total) accuracy = correct.item() / total.item() if coordinator.is_master(): - print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %') + print(f"Accuracy of the model on the test images: {accuracy * 100:.2f} %") return accuracy -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: nn.Module, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: for images, labels in pbar: images = images.cuda() labels = labels.cuda() @@ -91,7 +94,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: n optimizer.zero_grad() # Print log info - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) def main(): @@ -100,19 +103,20 @@ def main(): # ============================== parser = argparse.ArgumentParser() # FIXME(ver217): gemini is not supported resnet now - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'], - help="plugin to use") - parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") - parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") - parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") - parser.add_argument('--target_acc', - type=float, - default=None, - help="target accuracy. Raise exception if not reached") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "low_level_zero"], + help="plugin to use", + ) + parser.add_argument("-r", "--resume", type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") + parser.add_argument("-i", "--interval", type=int, default=5, help="interval of saving checkpoint") + parser.add_argument( + "--target_acc", type=float, default=None, help="target accuracy. Raise exception if not reached" + ) args = parser.parse_args() # ============================== @@ -136,13 +140,13 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) @@ -168,18 +172,17 @@ def main(): # ============================== # Boost with ColossalAI # ============================== - model, optimizer, criterion, _, lr_scheduler = booster.boost(model, - optimizer, - criterion=criterion, - lr_scheduler=lr_scheduler) + model, optimizer, criterion, _, lr_scheduler = booster.boost( + model, optimizer, criterion=criterion, lr_scheduler=lr_scheduler + ) # ============================== # Resume from checkpoint # ============================== if args.resume >= 0: - booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') - booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') - booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') + booster.load_model(model, f"{args.checkpoint}/model_{args.resume}.pth") + booster.load_optimizer(optimizer, f"{args.checkpoint}/optimizer_{args.resume}.pth") + booster.load_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{args.resume}.pth") # ============================== # Train model @@ -191,14 +194,14 @@ def main(): # save checkpoint if args.interval > 0 and (epoch + 1) % args.interval == 0: - booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') - booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') - booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') + booster.save_model(model, f"{args.checkpoint}/model_{epoch + 1}.pth") + booster.save_optimizer(optimizer, f"{args.checkpoint}/optimizer_{epoch + 1}.pth") + booster.save_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{epoch + 1}.pth") accuracy = evaluate(model, test_dataloader, coordinator) if args.target_acc is not None: - assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}' + assert accuracy >= args.target_acc, f"Accuracy {accuracy} is lower than target accuracy {args.target_acc}" -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py index 82a8f2ed97e4..226a4b320961 100644 --- a/examples/tutorial/new_api/cifar_vit/train.py +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -32,35 +32,37 @@ def vit_cifar(**kwargs): pretrained_cfg = _cfg(num_classes=10, input_size=(3, 32, 32), crop_pct=1.0) model_kwargs = dict(patch_size=4, embed_dim=512, depth=6, num_heads=8, drop_rate=0.1, mlp_ratio=1.0, **kwargs) - model = _create_vision_transformer('vit_cifar', pretrained_cfg=pretrained_cfg, **model_kwargs) + model = _create_vision_transformer("vit_cifar", pretrained_cfg=pretrained_cfg, **model_kwargs) return model def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): # transform - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), - ]) - transform_test = transforms.Compose([ - transforms.Resize(32), - transforms.ToTensor(), - transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), - ]) + transform_train = transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), + ] + ) + transform_test = transforms.Compose( + [ + transforms.Resize(32), + transforms.ToTensor(), + transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), + ] + ) # CIFAR-10 dataset - data_path = os.environ.get('DATA', './data') + data_path = os.environ.get("DATA", "./data") with coordinator.priority_execution(): - train_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=True, - transform=transform_train, - download=True) - test_dataset = torchvision.datasets.CIFAR10(root=data_path, - train=False, - transform=transform_test, - download=True) + train_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=True, transform=transform_train, download=True + ) + test_dataset = torchvision.datasets.CIFAR10( + root=data_path, train=False, transform=transform_test, download=True + ) # Data loader train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) @@ -84,14 +86,21 @@ def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoo dist.all_reduce(total) accuracy = correct.item() / total.item() if coordinator.is_master(): - print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %') + print(f"Accuracy of the model on the test images: {accuracy * 100:.2f} %") return accuracy -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + criterion: nn.Module, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: for images, labels in pbar: images = images.cuda() labels = labels.cuda() @@ -105,7 +114,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: n optimizer.zero_grad() # Print log info - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) def main(): @@ -114,19 +123,20 @@ def main(): # ============================== parser = argparse.ArgumentParser() # FIXME(ver217): gemini is not supported resnet now - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'], - help="plugin to use") - parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") - parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") - parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") - parser.add_argument('--target_acc', - type=float, - default=None, - help="target accuracy. Raise exception if not reached") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "low_level_zero"], + help="plugin to use", + ) + parser.add_argument("-r", "--resume", type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument("-c", "--checkpoint", type=str, default="./checkpoint", help="checkpoint directory") + parser.add_argument("-i", "--interval", type=int, default=5, help="interval of saving checkpoint") + parser.add_argument( + "--target_acc", type=float, default=None, help="target accuracy. Raise exception if not reached" + ) args = parser.parse_args() # ============================== @@ -150,13 +160,13 @@ def main(): # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) @@ -182,19 +192,17 @@ def main(): # ============================== # Boost with ColossalAI # ============================== - model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, - optimizer, - criterion=criterion, - dataloader=train_dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost( + model, optimizer, criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler + ) # ============================== # Resume from checkpoint # ============================== if args.resume >= 0: - booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') - booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') - booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') + booster.load_model(model, f"{args.checkpoint}/model_{args.resume}.pth") + booster.load_optimizer(optimizer, f"{args.checkpoint}/optimizer_{args.resume}.pth") + booster.load_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{args.resume}.pth") # ============================== # Train model @@ -206,14 +214,14 @@ def main(): # save checkpoint if args.interval > 0 and (epoch + 1) % args.interval == 0: - booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') - booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') - booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') + booster.save_model(model, f"{args.checkpoint}/model_{epoch + 1}.pth") + booster.save_optimizer(optimizer, f"{args.checkpoint}/optimizer_{epoch + 1}.pth") + booster.save_lr_scheduler(lr_scheduler, f"{args.checkpoint}/lr_scheduler_{epoch + 1}.pth") accuracy = evaluate(model, test_dataloader, coordinator) if args.target_acc is not None: - assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}' + assert accuracy >= args.target_acc, f"Accuracy {accuracy} is lower than target accuracy {args.target_acc}" -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/new_api/glue_bert/data.py b/examples/tutorial/new_api/glue_bert/data.py index 981cedcca8c2..ef51f938dc4f 100644 --- a/examples/tutorial/new_api/glue_bert/data.py +++ b/examples/tutorial/new_api/glue_bert/data.py @@ -5,7 +5,6 @@ class GLUEDataBuilder: - task_text_field_map = { "cola": ["sentence"], "sst2": ["sentence"], @@ -84,10 +83,9 @@ def prepare_data(self): AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) def train_dataloader(self): - return self.plugin.prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) + return self.plugin.prepare_dataloader( + self.dataset["train"], batch_size=self.train_batch_size, shuffle=True, drop_last=True + ) def val_dataloader(self): if len(self.eval_splits) == 1: @@ -108,7 +106,6 @@ def test_dataloader(self): ] def convert_to_features(self, example_batch): - # Either encode single sentence or sentence pairs if len(self.text_fields) > 1: texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) @@ -116,10 +113,9 @@ def convert_to_features(self, example_batch): texts_or_text_pairs = example_batch[self.text_fields[0]] # Tokenize the text/text pairs - features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, - max_length=self.max_seq_length, - padding='max_length', - truncation=True) + features = self.tokenizer.batch_encode_plus( + texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True + ) # Rename label to labels to make it easier to pass to model forward features["labels"] = example_batch["label"] diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py index 63bdfc5d02cf..7d69dbc066b3 100644 --- a/examples/tutorial/new_api/glue_bert/finetune.py +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -33,8 +33,14 @@ def move_to_cuda(batch): @torch.no_grad() -def evaluate(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str, - eval_splits: List[str], coordinator: DistCoordinator): +def evaluate( + model: nn.Module, + test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, + task_name: str, + eval_splits: List[str], + coordinator: DistCoordinator, +): metric = datasets.load_metric("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) model.eval() @@ -58,7 +64,7 @@ def evaluate_subset(dataloader: DataLoader): results = metric.compute() dist.all_reduce(accum_loss.div_(len(dataloader))) if coordinator.is_master(): - results['loss'] = accum_loss.item() / coordinator.world_size + results["loss"] = accum_loss.item() / coordinator.world_size return results if isinstance(test_dataloader, DataLoader): @@ -68,14 +74,21 @@ def evaluate_subset(dataloader: DataLoader): final_results = {} for split, sub_loader in zip(eval_splits, test_dataloader): results = evaluate_subset(sub_loader) - final_results.update({f'{k}_{split}': v for k, v in results.items()}) + final_results.update({f"{k}_{split}": v for k, v in results.items()}) return final_results -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader, - booster: Booster, coordinator: DistCoordinator): +def train_epoch( + epoch: int, + model: nn.Module, + optimizer: Optimizer, + lr_scheduler, + train_dataloader: DataLoader, + booster: Booster, + coordinator: DistCoordinator, +): model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + with tqdm(train_dataloader, desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]", disable=not coordinator.is_master()) as pbar: for batch in pbar: # Forward pass batch = move_to_cuda(batch) @@ -89,7 +102,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler lr_scheduler.step() # Print log info - pbar.set_postfix({'loss': loss.item()}) + pbar.set_postfix({"loss": loss.item()}) def main(): @@ -97,14 +110,16 @@ def main(): # Parse Arguments # ============================== parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], - help="plugin to use") - parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached") + parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="torch_ddp", + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero"], + help="plugin to use", + ) + parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") args = parser.parse_args() # ============================== @@ -115,19 +130,19 @@ def main(): # local_batch_size = BATCH_SIZE // coordinator.world_size lr = LEARNING_RATE * coordinator.world_size - model_name = 'bert-base-uncased' + model_name = "bert-base-uncased" # ============================== # Instantiate Plugin and Booster # ============================== booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): + if args.plugin == "torch_ddp_fp16": + booster_kwargs["mixed_precision"] = "fp16" + if args.plugin.startswith("torch_ddp"): plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': + elif args.plugin == "gemini": + plugin = GeminiPlugin(placement_policy="cuda", strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) @@ -135,11 +150,9 @@ def main(): # ============================== # Prepare Dataloader # ============================== - data_builder = GLUEDataBuilder(model_name, - plugin, - args.task, - train_batch_size=BATCH_SIZE, - eval_batch_size=BATCH_SIZE) + data_builder = GLUEDataBuilder( + model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE + ) train_dataloader = data_builder.train_dataloader() test_dataloader = data_builder.test_dataloader() @@ -185,14 +198,15 @@ def main(): for epoch in range(NUM_EPOCHS): train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) - results = evaluate(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, - coordinator) + results = evaluate( + model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, coordinator + ) if coordinator.is_master(): print(results) - if args.target_f1 is not None and 'f1' in results: - assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + if args.target_f1 is not None and "f1" in results: + assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/opt/inference/batch.py b/examples/tutorial/opt/inference/batch.py index 1a0876ca8338..e4e857b264a0 100644 --- a/examples/tutorial/opt/inference/batch.py +++ b/examples/tutorial/opt/inference/batch.py @@ -1,5 +1,6 @@ +from typing import Any, Deque, Hashable, List, Tuple + import torch -from typing import List, Deque, Tuple, Hashable, Any from energonai import BatchManager, SubmitEntry, TaskEntry @@ -10,15 +11,15 @@ def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None: self.pad_token_id = pad_token_id def _left_padding(self, batch_inputs): - max_len = max(len(inputs['input_ids']) for inputs in batch_inputs) - outputs = {'input_ids': [], 'attention_mask': []} + max_len = max(len(inputs["input_ids"]) for inputs in batch_inputs) + outputs = {"input_ids": [], "attention_mask": []} for inputs in batch_inputs: - input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask'] + input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"] padding_len = max_len - len(input_ids) input_ids = [self.pad_token_id] * padding_len + input_ids attention_mask = [0] * padding_len + attention_mask - outputs['input_ids'].append(input_ids) - outputs['attention_mask'].append(attention_mask) + outputs["input_ids"].append(input_ids) + outputs["attention_mask"].append(attention_mask) for k in outputs: outputs[k] = torch.tensor(outputs[k]) return outputs, max_len @@ -26,7 +27,7 @@ def _left_padding(self, batch_inputs): @staticmethod def _make_batch_key(entry: SubmitEntry) -> tuple: data = entry.data - return (data['top_k'], data['top_p'], data['temperature']) + return (data["top_k"], data["top_p"], data["temperature"]) def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: entry = q.popleft() @@ -37,7 +38,7 @@ def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: break if self._make_batch_key(entry) != self._make_batch_key(q[0]): break - if q[0].data['max_tokens'] > entry.data['max_tokens']: + if q[0].data["max_tokens"] > entry.data["max_tokens"]: break e = q.popleft() batch.append(e.data) @@ -45,12 +46,12 @@ def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]: inputs, max_len = self._left_padding(batch) trunc_lens = [] for data in batch: - trunc_lens.append(max_len + data['max_tokens']) - inputs['top_k'] = entry.data['top_k'] - inputs['top_p'] = entry.data['top_p'] - inputs['temperature'] = entry.data['temperature'] - inputs['max_tokens'] = max_len + entry.data['max_tokens'] - return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens} + trunc_lens.append(max_len + data["max_tokens"]) + inputs["top_k"] = entry.data["top_k"] + inputs["top_p"] = entry.data["top_p"] + inputs["temperature"] = entry.data["temperature"] + inputs["max_tokens"] = max_len + entry.data["max_tokens"] + return TaskEntry(tuple(uids), inputs), {"trunc_lens": trunc_lens} def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]: retval = [] diff --git a/examples/tutorial/opt/inference/benchmark/locustfile.py b/examples/tutorial/opt/inference/benchmark/locustfile.py index 4d829e5d83bf..76ef9d8cb3d6 100644 --- a/examples/tutorial/opt/inference/benchmark/locustfile.py +++ b/examples/tutorial/opt/inference/benchmark/locustfile.py @@ -1,15 +1,14 @@ from locust import HttpUser, task -from json import JSONDecodeError class GenerationUser(HttpUser): @task def generate(self): - prompt = 'Question: What is the longest river on the earth? Answer:' + prompt = "Question: What is the longest river on the earth? Answer:" for i in range(4, 9): - data = {'max_tokens': 2**i, 'prompt': prompt} - with self.client.post('/generation', json=data, catch_response=True) as response: + data = {"max_tokens": 2**i, "prompt": prompt} + with self.client.post("/generation", json=data, catch_response=True) as response: if response.status_code in (200, 406): response.success() else: - response.failure('Response wrong') + response.failure("Response wrong") diff --git a/examples/tutorial/opt/inference/cache.py b/examples/tutorial/opt/inference/cache.py index 30febc44fbb3..1eb7dac2ea04 100644 --- a/examples/tutorial/opt/inference/cache.py +++ b/examples/tutorial/opt/inference/cache.py @@ -1,7 +1,7 @@ from collections import OrderedDict -from threading import Lock from contextlib import contextmanager -from typing import List, Any, Hashable, Dict +from threading import Lock +from typing import Any, Dict, Hashable, List class MissCacheError(Exception): diff --git a/examples/tutorial/opt/inference/opt_fastapi.py b/examples/tutorial/opt/inference/opt_fastapi.py index cbfc2a22e7c0..6475284e535b 100644 --- a/examples/tutorial/opt/inference/opt_fastapi.py +++ b/examples/tutorial/opt/inference/opt_fastapi.py @@ -4,20 +4,21 @@ from typing import Optional import uvicorn +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError from energonai import QueueFullError, launch_engine from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel, Field from transformers import GPT2Tokenizer -from batch import BatchManagerForGeneration -from cache import ListCache, MissCacheError - class GenerationTaskReq(BaseModel): max_tokens: int = Field(gt=0, le=256, example=64) prompt: str = Field( - min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + min_length=1, + example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:", + ) top_k: Optional[int] = Field(default=None, gt=0, example=50) top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) @@ -26,7 +27,7 @@ class GenerationTaskReq(BaseModel): app = FastAPI() -@app.post('/generation') +@app.post("/generation") async def generate(data: GenerationTaskReq, request: Request): logger.info(f'{request.client.host}:{request.client.port} - "{request.method} {request.url.path}" - {data}') key = (data.prompt, data.max_tokens) @@ -35,13 +36,13 @@ async def generate(data: GenerationTaskReq, request: Request): raise MissCacheError() outputs = cache.get(key) output = random.choice(outputs) - logger.info('Cache hit') + logger.info("Cache hit") except MissCacheError: inputs = tokenizer(data.prompt, truncation=True, max_length=512) - inputs['max_tokens'] = data.max_tokens - inputs['top_k'] = data.top_k - inputs['top_p'] = data.top_p - inputs['temperature'] = data.temperature + inputs["max_tokens"] = data.max_tokens + inputs["top_k"] = data.top_k + inputs["top_p"] = data.top_p + inputs["temperature"] = data.temperature try: uid = id(data) engine.submit(uid, inputs) @@ -52,7 +53,7 @@ async def generate(data: GenerationTaskReq, request: Request): except QueueFullError as e: raise HTTPException(status_code=406, detail=e.args[0]) - return {'text': output} + return {"text": output} @app.on_event("shutdown") @@ -64,60 +65,72 @@ async def shutdown(*_): def get_model_fn(model_name: str): - model_map = { - 'opt-125m': opt_125M, - 'opt-6.7b': opt_6B, - 'opt-30b': opt_30B, - 'opt-175b': opt_175B - } + model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B} return model_map[model_name] def print_args(args: argparse.Namespace): - print('\n==> Args:') + print("\n==> Args:") for k, v in args.__dict__.items(): - print(f'{k} = {v}') + print(f"{k} = {v}") FIXED_CACHE_KEYS = [ - ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), - ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), - ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) + ( + "Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:", + 64, + ), + ( + "A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.", + 64, + ), + ( + "English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", + 64, + ), ] -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) - parser.add_argument('--tp', type=int, default=1) - parser.add_argument('--master_host', default='localhost') - parser.add_argument('--master_port', type=int, default=19990) - parser.add_argument('--rpc_port', type=int, default=19980) - parser.add_argument('--max_batch_size', type=int, default=8) - parser.add_argument('--pipe_size', type=int, default=1) - parser.add_argument('--queue_size', type=int, default=0) - parser.add_argument('--http_host', default='0.0.0.0') - parser.add_argument('--http_port', type=int, default=7070) - parser.add_argument('--checkpoint', default=None) - parser.add_argument('--cache_size', type=int, default=0) - parser.add_argument('--cache_list_size', type=int, default=1) + parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"]) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--master_host", default="localhost") + parser.add_argument("--master_port", type=int, default=19990) + parser.add_argument("--rpc_port", type=int, default=19980) + parser.add_argument("--max_batch_size", type=int, default=8) + parser.add_argument("--pipe_size", type=int, default=1) + parser.add_argument("--queue_size", type=int, default=0) + parser.add_argument("--http_host", default="0.0.0.0") + parser.add_argument("--http_port", type=int, default=7070) + parser.add_argument("--checkpoint", default=None) + parser.add_argument("--cache_size", type=int, default=0) + parser.add_argument("--cache_list_size", type=int, default=1) args = parser.parse_args() print_args(args) model_kwargs = {} if args.checkpoint is not None: - model_kwargs['checkpoint'] = args.checkpoint + model_kwargs["checkpoint"] = args.checkpoint logger = logging.getLogger(__name__) - tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b") if args.cache_size > 0: cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) else: cache = None - engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), - batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, - pad_token_id=tokenizer.pad_token_id), - pipe_size=args.pipe_size, - queue_size=args.queue_size, - **model_kwargs) + engine = launch_engine( + args.tp, + 1, + args.master_host, + args.master_port, + args.rpc_port, + get_model_fn(args.model), + batch_manager=BatchManagerForGeneration( + max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id + ), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs, + ) config = uvicorn.Config(app, host=args.http_host, port=args.http_port) server = uvicorn.Server(config=config) server.run() diff --git a/examples/tutorial/opt/inference/opt_server.py b/examples/tutorial/opt/inference/opt_server.py index 8dab82622c59..7f591b9be111 100644 --- a/examples/tutorial/opt/inference/opt_server.py +++ b/examples/tutorial/opt/inference/opt_server.py @@ -1,33 +1,36 @@ -import logging import argparse +import logging import random -from torch import Tensor -from pydantic import BaseModel, Field from typing import Optional -from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B -from transformers import GPT2Tokenizer -from energonai import launch_engine, QueueFullError + +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError +from energonai import QueueFullError, launch_engine +from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B +from pydantic import BaseModel, Field from sanic import Sanic from sanic.request import Request from sanic.response import json -from sanic_ext import validate, openapi -from batch import BatchManagerForGeneration -from cache import ListCache, MissCacheError +from sanic_ext import openapi, validate +from torch import Tensor +from transformers import GPT2Tokenizer class GenerationTaskReq(BaseModel): max_tokens: int = Field(gt=0, le=256, example=64) prompt: str = Field( - min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + min_length=1, + example="Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:", + ) top_k: Optional[int] = Field(default=None, gt=0, example=50) top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) -app = Sanic('opt') +app = Sanic("opt") -@app.post('/generation') +@app.post("/generation") @openapi.body(GenerationTaskReq) @validate(json=GenerationTaskReq) async def generate(request: Request, body: GenerationTaskReq): @@ -38,13 +41,13 @@ async def generate(request: Request, body: GenerationTaskReq): raise MissCacheError() outputs = cache.get(key) output = random.choice(outputs) - logger.info('Cache hit') + logger.info("Cache hit") except MissCacheError: inputs = tokenizer(body.prompt, truncation=True, max_length=512) - inputs['max_tokens'] = body.max_tokens - inputs['top_k'] = body.top_k - inputs['top_p'] = body.top_p - inputs['temperature'] = body.temperature + inputs["max_tokens"] = body.max_tokens + inputs["top_k"] = body.top_k + inputs["top_p"] = body.top_p + inputs["temperature"] = body.temperature try: uid = id(body) engine.submit(uid, inputs) @@ -54,9 +57,9 @@ async def generate(request: Request, body: GenerationTaskReq): if cache is not None: cache.add(key, output) except QueueFullError as e: - return json({'detail': e.args[0]}, status=406) + return json({"detail": e.args[0]}, status=406) - return json({'text': output}) + return json({"text": output}) @app.after_server_stop @@ -65,58 +68,70 @@ def shutdown(*_): def get_model_fn(model_name: str): - model_map = { - 'opt-125m': opt_125M, - 'opt-6.7b': opt_6B, - 'opt-30b': opt_30B, - 'opt-175b': opt_175B - } + model_map = {"opt-125m": opt_125M, "opt-6.7b": opt_6B, "opt-30b": opt_30B, "opt-175b": opt_175B} return model_map[model_name] def print_args(args: argparse.Namespace): - print('\n==> Args:') + print("\n==> Args:") for k, v in args.__dict__.items(): - print(f'{k} = {v}') + print(f"{k} = {v}") FIXED_CACHE_KEYS = [ - ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), - ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), - ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) + ( + "Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:", + 64, + ), + ( + "A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.", + 64, + ), + ( + "English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", + 64, + ), ] -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('model', choices=['opt-125m', 'opt-6.7b', 'opt-30b', 'opt-175b']) - parser.add_argument('--tp', type=int, default=1) - parser.add_argument('--master_host', default='localhost') - parser.add_argument('--master_port', type=int, default=19990) - parser.add_argument('--rpc_port', type=int, default=19980) - parser.add_argument('--max_batch_size', type=int, default=8) - parser.add_argument('--pipe_size', type=int, default=1) - parser.add_argument('--queue_size', type=int, default=0) - parser.add_argument('--http_host', default='0.0.0.0') - parser.add_argument('--http_port', type=int, default=7070) - parser.add_argument('--checkpoint', default=None) - parser.add_argument('--cache_size', type=int, default=0) - parser.add_argument('--cache_list_size', type=int, default=1) + parser.add_argument("model", choices=["opt-125m", "opt-6.7b", "opt-30b", "opt-175b"]) + parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--master_host", default="localhost") + parser.add_argument("--master_port", type=int, default=19990) + parser.add_argument("--rpc_port", type=int, default=19980) + parser.add_argument("--max_batch_size", type=int, default=8) + parser.add_argument("--pipe_size", type=int, default=1) + parser.add_argument("--queue_size", type=int, default=0) + parser.add_argument("--http_host", default="0.0.0.0") + parser.add_argument("--http_port", type=int, default=7070) + parser.add_argument("--checkpoint", default=None) + parser.add_argument("--cache_size", type=int, default=0) + parser.add_argument("--cache_list_size", type=int, default=1) args = parser.parse_args() print_args(args) model_kwargs = {} if args.checkpoint is not None: - model_kwargs['checkpoint'] = args.checkpoint + model_kwargs["checkpoint"] = args.checkpoint logger = logging.getLogger(__name__) - tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-30b') + tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b") if args.cache_size > 0: cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) else: cache = None - engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), - batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, - pad_token_id=tokenizer.pad_token_id), - pipe_size=args.pipe_size, - queue_size=args.queue_size, - **model_kwargs) + engine = launch_engine( + args.tp, + 1, + args.master_host, + args.master_port, + args.rpc_port, + get_model_fn(args.model), + batch_manager=BatchManagerForGeneration( + max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id + ), + pipe_size=args.pipe_size, + queue_size=args.queue_size, + **model_kwargs, + ) app.run(args.http_host, args.http_port) diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/README.md b/examples/tutorial/opt/inference/script/process-opt-175b/README.md index bc3cba72df33..665c459fec69 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/README.md +++ b/examples/tutorial/opt/inference/script/process-opt-175b/README.md @@ -43,4 +43,3 @@ Finally, you will get 8 files in `` with following checksums: 5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt ``` - diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py index a17ddd4fa173..36c9001fe3f1 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py +++ b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py @@ -14,42 +14,45 @@ def load_json(path: str): def parse_shape_info(flat_dir: str): - data = load_json(os.path.join(flat_dir, 'shape.json')) + data = load_json(os.path.join(flat_dir, "shape.json")) flat_info = defaultdict(lambda: defaultdict(list)) for k, shape in data.items(): - matched = re.match(r'decoder.layers.\d+', k) + matched = re.match(r"decoder.layers.\d+", k) if matched is None: - flat_key = 'flat_param_0' + flat_key = "flat_param_0" else: - flat_key = f'{matched[0]}.flat_param_0' - flat_info[flat_key]['names'].append(k) - flat_info[flat_key]['shapes'].append(shape) - flat_info[flat_key]['numels'].append(int(np.prod(shape))) + flat_key = f"{matched[0]}.flat_param_0" + flat_info[flat_key]["names"].append(k) + flat_info[flat_key]["shapes"].append(shape) + flat_info[flat_key]["numels"].append(int(np.prod(shape))) return flat_info def convert(flat_dir: str, output_dir: str, part: int): - flat_path = os.path.join(flat_dir, f'reshard-model_part-{part}-shard0.pt') - output_path = os.path.join(output_dir, f'reshard-model_part-{part}.pt') - flat_meta = load_json(os.path.join(flat_dir, 'flat-meta.json')) + flat_path = os.path.join(flat_dir, f"reshard-model_part-{part}-shard0.pt") + output_path = os.path.join(output_dir, f"reshard-model_part-{part}.pt") + flat_meta = load_json(os.path.join(flat_dir, "flat-meta.json")) flat_sd = torch.load(flat_path) - print(f'Loaded flat state dict from {flat_path}') + print(f"Loaded flat state dict from {flat_path}") output_sd = {} for flat_key, param_meta in flat_meta.items(): - flat_param = flat_sd['model'][flat_key] - assert sum(param_meta['numels']) == flat_param.numel( + flat_param = flat_sd["model"][flat_key] + assert ( + sum(param_meta["numels"]) == flat_param.numel() ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}' - for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])): + for name, shape, param in zip( + param_meta["names"], param_meta["shapes"], flat_param.split(param_meta["numels"]) + ): output_sd[name] = param.view(shape) torch.save(output_sd, output_path) - print(f'Saved unflat state dict to {output_path}') + print(f"Saved unflat state dict to {output_path}") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('flat_dir') - parser.add_argument('output_dir') - parser.add_argument('part', type=int) + parser.add_argument("flat_dir") + parser.add_argument("output_dir") + parser.add_argument("part", type=int) args = parser.parse_args() convert(args.flat_dir, args.output_dir, args.part) diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json index 59d285565cfd..ce70451cc4e5 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json +++ b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json @@ -1 +1,6944 @@ -{"flat_param_0": {"names": ["decoder.embed_tokens.weight", "decoder.embed_positions.weight", "decoder.layer_norm.weight", "decoder.layer_norm.bias"], "shapes": [[6284, 12288], [2050, 12288], [12288], [12288]], "numels": [77217792, 25190400, 12288, 12288]}, "decoder.layers.0.flat_param_0": {"names": ["decoder.layers.0.self_attn.qkv_proj.weight", "decoder.layers.0.self_attn.qkv_proj.bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.self_attn_layer_norm.weight", "decoder.layers.0.self_attn_layer_norm.bias", "decoder.layers.0.fc1.weight", "decoder.layers.0.fc1.bias", "decoder.layers.0.fc2.weight", "decoder.layers.0.fc2.bias", "decoder.layers.0.final_layer_norm.weight", "decoder.layers.0.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.1.flat_param_0": {"names": ["decoder.layers.1.self_attn.qkv_proj.weight", "decoder.layers.1.self_attn.qkv_proj.bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.self_attn_layer_norm.weight", "decoder.layers.1.self_attn_layer_norm.bias", "decoder.layers.1.fc1.weight", "decoder.layers.1.fc1.bias", "decoder.layers.1.fc2.weight", "decoder.layers.1.fc2.bias", "decoder.layers.1.final_layer_norm.weight", "decoder.layers.1.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.2.flat_param_0": {"names": ["decoder.layers.2.self_attn.qkv_proj.weight", "decoder.layers.2.self_attn.qkv_proj.bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.self_attn_layer_norm.weight", "decoder.layers.2.self_attn_layer_norm.bias", "decoder.layers.2.fc1.weight", "decoder.layers.2.fc1.bias", "decoder.layers.2.fc2.weight", "decoder.layers.2.fc2.bias", "decoder.layers.2.final_layer_norm.weight", "decoder.layers.2.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.3.flat_param_0": {"names": ["decoder.layers.3.self_attn.qkv_proj.weight", "decoder.layers.3.self_attn.qkv_proj.bias", "decoder.layers.3.self_attn.out_proj.weight", "decoder.layers.3.self_attn.out_proj.bias", "decoder.layers.3.self_attn_layer_norm.weight", "decoder.layers.3.self_attn_layer_norm.bias", "decoder.layers.3.fc1.weight", "decoder.layers.3.fc1.bias", "decoder.layers.3.fc2.weight", "decoder.layers.3.fc2.bias", "decoder.layers.3.final_layer_norm.weight", "decoder.layers.3.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.4.flat_param_0": {"names": ["decoder.layers.4.self_attn.qkv_proj.weight", "decoder.layers.4.self_attn.qkv_proj.bias", "decoder.layers.4.self_attn.out_proj.weight", "decoder.layers.4.self_attn.out_proj.bias", "decoder.layers.4.self_attn_layer_norm.weight", "decoder.layers.4.self_attn_layer_norm.bias", "decoder.layers.4.fc1.weight", "decoder.layers.4.fc1.bias", "decoder.layers.4.fc2.weight", "decoder.layers.4.fc2.bias", "decoder.layers.4.final_layer_norm.weight", "decoder.layers.4.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.5.flat_param_0": {"names": ["decoder.layers.5.self_attn.qkv_proj.weight", "decoder.layers.5.self_attn.qkv_proj.bias", "decoder.layers.5.self_attn.out_proj.weight", "decoder.layers.5.self_attn.out_proj.bias", "decoder.layers.5.self_attn_layer_norm.weight", "decoder.layers.5.self_attn_layer_norm.bias", "decoder.layers.5.fc1.weight", "decoder.layers.5.fc1.bias", "decoder.layers.5.fc2.weight", "decoder.layers.5.fc2.bias", "decoder.layers.5.final_layer_norm.weight", "decoder.layers.5.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.6.flat_param_0": {"names": ["decoder.layers.6.self_attn.qkv_proj.weight", "decoder.layers.6.self_attn.qkv_proj.bias", "decoder.layers.6.self_attn.out_proj.weight", "decoder.layers.6.self_attn.out_proj.bias", "decoder.layers.6.self_attn_layer_norm.weight", "decoder.layers.6.self_attn_layer_norm.bias", "decoder.layers.6.fc1.weight", "decoder.layers.6.fc1.bias", "decoder.layers.6.fc2.weight", "decoder.layers.6.fc2.bias", "decoder.layers.6.final_layer_norm.weight", "decoder.layers.6.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.7.flat_param_0": {"names": ["decoder.layers.7.self_attn.qkv_proj.weight", "decoder.layers.7.self_attn.qkv_proj.bias", "decoder.layers.7.self_attn.out_proj.weight", "decoder.layers.7.self_attn.out_proj.bias", "decoder.layers.7.self_attn_layer_norm.weight", "decoder.layers.7.self_attn_layer_norm.bias", "decoder.layers.7.fc1.weight", "decoder.layers.7.fc1.bias", "decoder.layers.7.fc2.weight", "decoder.layers.7.fc2.bias", "decoder.layers.7.final_layer_norm.weight", "decoder.layers.7.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.8.flat_param_0": {"names": ["decoder.layers.8.self_attn.qkv_proj.weight", "decoder.layers.8.self_attn.qkv_proj.bias", "decoder.layers.8.self_attn.out_proj.weight", "decoder.layers.8.self_attn.out_proj.bias", "decoder.layers.8.self_attn_layer_norm.weight", "decoder.layers.8.self_attn_layer_norm.bias", "decoder.layers.8.fc1.weight", "decoder.layers.8.fc1.bias", "decoder.layers.8.fc2.weight", "decoder.layers.8.fc2.bias", "decoder.layers.8.final_layer_norm.weight", "decoder.layers.8.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.9.flat_param_0": {"names": ["decoder.layers.9.self_attn.qkv_proj.weight", "decoder.layers.9.self_attn.qkv_proj.bias", "decoder.layers.9.self_attn.out_proj.weight", "decoder.layers.9.self_attn.out_proj.bias", "decoder.layers.9.self_attn_layer_norm.weight", "decoder.layers.9.self_attn_layer_norm.bias", "decoder.layers.9.fc1.weight", "decoder.layers.9.fc1.bias", "decoder.layers.9.fc2.weight", "decoder.layers.9.fc2.bias", "decoder.layers.9.final_layer_norm.weight", "decoder.layers.9.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.10.flat_param_0": {"names": ["decoder.layers.10.self_attn.qkv_proj.weight", "decoder.layers.10.self_attn.qkv_proj.bias", "decoder.layers.10.self_attn.out_proj.weight", "decoder.layers.10.self_attn.out_proj.bias", "decoder.layers.10.self_attn_layer_norm.weight", "decoder.layers.10.self_attn_layer_norm.bias", "decoder.layers.10.fc1.weight", "decoder.layers.10.fc1.bias", "decoder.layers.10.fc2.weight", "decoder.layers.10.fc2.bias", "decoder.layers.10.final_layer_norm.weight", "decoder.layers.10.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.11.flat_param_0": {"names": ["decoder.layers.11.self_attn.qkv_proj.weight", "decoder.layers.11.self_attn.qkv_proj.bias", "decoder.layers.11.self_attn.out_proj.weight", "decoder.layers.11.self_attn.out_proj.bias", "decoder.layers.11.self_attn_layer_norm.weight", "decoder.layers.11.self_attn_layer_norm.bias", "decoder.layers.11.fc1.weight", "decoder.layers.11.fc1.bias", "decoder.layers.11.fc2.weight", "decoder.layers.11.fc2.bias", "decoder.layers.11.final_layer_norm.weight", "decoder.layers.11.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.12.flat_param_0": {"names": ["decoder.layers.12.self_attn.qkv_proj.weight", "decoder.layers.12.self_attn.qkv_proj.bias", "decoder.layers.12.self_attn.out_proj.weight", "decoder.layers.12.self_attn.out_proj.bias", "decoder.layers.12.self_attn_layer_norm.weight", "decoder.layers.12.self_attn_layer_norm.bias", "decoder.layers.12.fc1.weight", "decoder.layers.12.fc1.bias", "decoder.layers.12.fc2.weight", "decoder.layers.12.fc2.bias", "decoder.layers.12.final_layer_norm.weight", "decoder.layers.12.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.13.flat_param_0": {"names": ["decoder.layers.13.self_attn.qkv_proj.weight", "decoder.layers.13.self_attn.qkv_proj.bias", "decoder.layers.13.self_attn.out_proj.weight", "decoder.layers.13.self_attn.out_proj.bias", "decoder.layers.13.self_attn_layer_norm.weight", "decoder.layers.13.self_attn_layer_norm.bias", "decoder.layers.13.fc1.weight", "decoder.layers.13.fc1.bias", "decoder.layers.13.fc2.weight", "decoder.layers.13.fc2.bias", "decoder.layers.13.final_layer_norm.weight", "decoder.layers.13.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.14.flat_param_0": {"names": ["decoder.layers.14.self_attn.qkv_proj.weight", "decoder.layers.14.self_attn.qkv_proj.bias", "decoder.layers.14.self_attn.out_proj.weight", "decoder.layers.14.self_attn.out_proj.bias", "decoder.layers.14.self_attn_layer_norm.weight", "decoder.layers.14.self_attn_layer_norm.bias", "decoder.layers.14.fc1.weight", "decoder.layers.14.fc1.bias", "decoder.layers.14.fc2.weight", "decoder.layers.14.fc2.bias", "decoder.layers.14.final_layer_norm.weight", "decoder.layers.14.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.15.flat_param_0": {"names": ["decoder.layers.15.self_attn.qkv_proj.weight", "decoder.layers.15.self_attn.qkv_proj.bias", "decoder.layers.15.self_attn.out_proj.weight", "decoder.layers.15.self_attn.out_proj.bias", "decoder.layers.15.self_attn_layer_norm.weight", "decoder.layers.15.self_attn_layer_norm.bias", "decoder.layers.15.fc1.weight", "decoder.layers.15.fc1.bias", "decoder.layers.15.fc2.weight", "decoder.layers.15.fc2.bias", "decoder.layers.15.final_layer_norm.weight", "decoder.layers.15.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.16.flat_param_0": {"names": ["decoder.layers.16.self_attn.qkv_proj.weight", "decoder.layers.16.self_attn.qkv_proj.bias", "decoder.layers.16.self_attn.out_proj.weight", "decoder.layers.16.self_attn.out_proj.bias", "decoder.layers.16.self_attn_layer_norm.weight", "decoder.layers.16.self_attn_layer_norm.bias", "decoder.layers.16.fc1.weight", "decoder.layers.16.fc1.bias", "decoder.layers.16.fc2.weight", "decoder.layers.16.fc2.bias", "decoder.layers.16.final_layer_norm.weight", "decoder.layers.16.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.17.flat_param_0": {"names": ["decoder.layers.17.self_attn.qkv_proj.weight", "decoder.layers.17.self_attn.qkv_proj.bias", "decoder.layers.17.self_attn.out_proj.weight", "decoder.layers.17.self_attn.out_proj.bias", "decoder.layers.17.self_attn_layer_norm.weight", "decoder.layers.17.self_attn_layer_norm.bias", "decoder.layers.17.fc1.weight", "decoder.layers.17.fc1.bias", "decoder.layers.17.fc2.weight", "decoder.layers.17.fc2.bias", "decoder.layers.17.final_layer_norm.weight", "decoder.layers.17.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.18.flat_param_0": {"names": ["decoder.layers.18.self_attn.qkv_proj.weight", "decoder.layers.18.self_attn.qkv_proj.bias", "decoder.layers.18.self_attn.out_proj.weight", "decoder.layers.18.self_attn.out_proj.bias", "decoder.layers.18.self_attn_layer_norm.weight", "decoder.layers.18.self_attn_layer_norm.bias", "decoder.layers.18.fc1.weight", "decoder.layers.18.fc1.bias", "decoder.layers.18.fc2.weight", "decoder.layers.18.fc2.bias", "decoder.layers.18.final_layer_norm.weight", "decoder.layers.18.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.19.flat_param_0": {"names": ["decoder.layers.19.self_attn.qkv_proj.weight", "decoder.layers.19.self_attn.qkv_proj.bias", "decoder.layers.19.self_attn.out_proj.weight", "decoder.layers.19.self_attn.out_proj.bias", "decoder.layers.19.self_attn_layer_norm.weight", "decoder.layers.19.self_attn_layer_norm.bias", "decoder.layers.19.fc1.weight", "decoder.layers.19.fc1.bias", "decoder.layers.19.fc2.weight", "decoder.layers.19.fc2.bias", "decoder.layers.19.final_layer_norm.weight", "decoder.layers.19.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.20.flat_param_0": {"names": ["decoder.layers.20.self_attn.qkv_proj.weight", "decoder.layers.20.self_attn.qkv_proj.bias", "decoder.layers.20.self_attn.out_proj.weight", "decoder.layers.20.self_attn.out_proj.bias", "decoder.layers.20.self_attn_layer_norm.weight", "decoder.layers.20.self_attn_layer_norm.bias", "decoder.layers.20.fc1.weight", "decoder.layers.20.fc1.bias", "decoder.layers.20.fc2.weight", "decoder.layers.20.fc2.bias", "decoder.layers.20.final_layer_norm.weight", "decoder.layers.20.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.21.flat_param_0": {"names": ["decoder.layers.21.self_attn.qkv_proj.weight", "decoder.layers.21.self_attn.qkv_proj.bias", "decoder.layers.21.self_attn.out_proj.weight", "decoder.layers.21.self_attn.out_proj.bias", "decoder.layers.21.self_attn_layer_norm.weight", "decoder.layers.21.self_attn_layer_norm.bias", "decoder.layers.21.fc1.weight", "decoder.layers.21.fc1.bias", "decoder.layers.21.fc2.weight", "decoder.layers.21.fc2.bias", "decoder.layers.21.final_layer_norm.weight", "decoder.layers.21.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.22.flat_param_0": {"names": ["decoder.layers.22.self_attn.qkv_proj.weight", "decoder.layers.22.self_attn.qkv_proj.bias", "decoder.layers.22.self_attn.out_proj.weight", "decoder.layers.22.self_attn.out_proj.bias", "decoder.layers.22.self_attn_layer_norm.weight", "decoder.layers.22.self_attn_layer_norm.bias", "decoder.layers.22.fc1.weight", "decoder.layers.22.fc1.bias", "decoder.layers.22.fc2.weight", "decoder.layers.22.fc2.bias", "decoder.layers.22.final_layer_norm.weight", "decoder.layers.22.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.23.flat_param_0": {"names": ["decoder.layers.23.self_attn.qkv_proj.weight", "decoder.layers.23.self_attn.qkv_proj.bias", "decoder.layers.23.self_attn.out_proj.weight", "decoder.layers.23.self_attn.out_proj.bias", "decoder.layers.23.self_attn_layer_norm.weight", "decoder.layers.23.self_attn_layer_norm.bias", "decoder.layers.23.fc1.weight", "decoder.layers.23.fc1.bias", "decoder.layers.23.fc2.weight", "decoder.layers.23.fc2.bias", "decoder.layers.23.final_layer_norm.weight", "decoder.layers.23.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.24.flat_param_0": {"names": ["decoder.layers.24.self_attn.qkv_proj.weight", "decoder.layers.24.self_attn.qkv_proj.bias", "decoder.layers.24.self_attn.out_proj.weight", "decoder.layers.24.self_attn.out_proj.bias", "decoder.layers.24.self_attn_layer_norm.weight", "decoder.layers.24.self_attn_layer_norm.bias", "decoder.layers.24.fc1.weight", "decoder.layers.24.fc1.bias", "decoder.layers.24.fc2.weight", "decoder.layers.24.fc2.bias", "decoder.layers.24.final_layer_norm.weight", "decoder.layers.24.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.25.flat_param_0": {"names": ["decoder.layers.25.self_attn.qkv_proj.weight", "decoder.layers.25.self_attn.qkv_proj.bias", "decoder.layers.25.self_attn.out_proj.weight", "decoder.layers.25.self_attn.out_proj.bias", "decoder.layers.25.self_attn_layer_norm.weight", "decoder.layers.25.self_attn_layer_norm.bias", "decoder.layers.25.fc1.weight", "decoder.layers.25.fc1.bias", "decoder.layers.25.fc2.weight", "decoder.layers.25.fc2.bias", "decoder.layers.25.final_layer_norm.weight", "decoder.layers.25.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.26.flat_param_0": {"names": ["decoder.layers.26.self_attn.qkv_proj.weight", "decoder.layers.26.self_attn.qkv_proj.bias", "decoder.layers.26.self_attn.out_proj.weight", "decoder.layers.26.self_attn.out_proj.bias", "decoder.layers.26.self_attn_layer_norm.weight", "decoder.layers.26.self_attn_layer_norm.bias", "decoder.layers.26.fc1.weight", "decoder.layers.26.fc1.bias", "decoder.layers.26.fc2.weight", "decoder.layers.26.fc2.bias", "decoder.layers.26.final_layer_norm.weight", "decoder.layers.26.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.27.flat_param_0": {"names": ["decoder.layers.27.self_attn.qkv_proj.weight", "decoder.layers.27.self_attn.qkv_proj.bias", "decoder.layers.27.self_attn.out_proj.weight", "decoder.layers.27.self_attn.out_proj.bias", "decoder.layers.27.self_attn_layer_norm.weight", "decoder.layers.27.self_attn_layer_norm.bias", "decoder.layers.27.fc1.weight", "decoder.layers.27.fc1.bias", "decoder.layers.27.fc2.weight", "decoder.layers.27.fc2.bias", "decoder.layers.27.final_layer_norm.weight", "decoder.layers.27.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.28.flat_param_0": {"names": ["decoder.layers.28.self_attn.qkv_proj.weight", "decoder.layers.28.self_attn.qkv_proj.bias", "decoder.layers.28.self_attn.out_proj.weight", "decoder.layers.28.self_attn.out_proj.bias", "decoder.layers.28.self_attn_layer_norm.weight", "decoder.layers.28.self_attn_layer_norm.bias", "decoder.layers.28.fc1.weight", "decoder.layers.28.fc1.bias", "decoder.layers.28.fc2.weight", "decoder.layers.28.fc2.bias", "decoder.layers.28.final_layer_norm.weight", "decoder.layers.28.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.29.flat_param_0": {"names": ["decoder.layers.29.self_attn.qkv_proj.weight", "decoder.layers.29.self_attn.qkv_proj.bias", "decoder.layers.29.self_attn.out_proj.weight", "decoder.layers.29.self_attn.out_proj.bias", "decoder.layers.29.self_attn_layer_norm.weight", "decoder.layers.29.self_attn_layer_norm.bias", "decoder.layers.29.fc1.weight", "decoder.layers.29.fc1.bias", "decoder.layers.29.fc2.weight", "decoder.layers.29.fc2.bias", "decoder.layers.29.final_layer_norm.weight", "decoder.layers.29.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.30.flat_param_0": {"names": ["decoder.layers.30.self_attn.qkv_proj.weight", "decoder.layers.30.self_attn.qkv_proj.bias", "decoder.layers.30.self_attn.out_proj.weight", "decoder.layers.30.self_attn.out_proj.bias", "decoder.layers.30.self_attn_layer_norm.weight", "decoder.layers.30.self_attn_layer_norm.bias", "decoder.layers.30.fc1.weight", "decoder.layers.30.fc1.bias", "decoder.layers.30.fc2.weight", "decoder.layers.30.fc2.bias", "decoder.layers.30.final_layer_norm.weight", "decoder.layers.30.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.31.flat_param_0": {"names": ["decoder.layers.31.self_attn.qkv_proj.weight", "decoder.layers.31.self_attn.qkv_proj.bias", "decoder.layers.31.self_attn.out_proj.weight", "decoder.layers.31.self_attn.out_proj.bias", "decoder.layers.31.self_attn_layer_norm.weight", "decoder.layers.31.self_attn_layer_norm.bias", "decoder.layers.31.fc1.weight", "decoder.layers.31.fc1.bias", "decoder.layers.31.fc2.weight", "decoder.layers.31.fc2.bias", "decoder.layers.31.final_layer_norm.weight", "decoder.layers.31.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.32.flat_param_0": {"names": ["decoder.layers.32.self_attn.qkv_proj.weight", "decoder.layers.32.self_attn.qkv_proj.bias", "decoder.layers.32.self_attn.out_proj.weight", "decoder.layers.32.self_attn.out_proj.bias", "decoder.layers.32.self_attn_layer_norm.weight", "decoder.layers.32.self_attn_layer_norm.bias", "decoder.layers.32.fc1.weight", "decoder.layers.32.fc1.bias", "decoder.layers.32.fc2.weight", "decoder.layers.32.fc2.bias", "decoder.layers.32.final_layer_norm.weight", "decoder.layers.32.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.33.flat_param_0": {"names": ["decoder.layers.33.self_attn.qkv_proj.weight", "decoder.layers.33.self_attn.qkv_proj.bias", "decoder.layers.33.self_attn.out_proj.weight", "decoder.layers.33.self_attn.out_proj.bias", "decoder.layers.33.self_attn_layer_norm.weight", "decoder.layers.33.self_attn_layer_norm.bias", "decoder.layers.33.fc1.weight", "decoder.layers.33.fc1.bias", "decoder.layers.33.fc2.weight", "decoder.layers.33.fc2.bias", "decoder.layers.33.final_layer_norm.weight", "decoder.layers.33.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.34.flat_param_0": {"names": ["decoder.layers.34.self_attn.qkv_proj.weight", "decoder.layers.34.self_attn.qkv_proj.bias", "decoder.layers.34.self_attn.out_proj.weight", "decoder.layers.34.self_attn.out_proj.bias", "decoder.layers.34.self_attn_layer_norm.weight", "decoder.layers.34.self_attn_layer_norm.bias", "decoder.layers.34.fc1.weight", "decoder.layers.34.fc1.bias", "decoder.layers.34.fc2.weight", "decoder.layers.34.fc2.bias", "decoder.layers.34.final_layer_norm.weight", "decoder.layers.34.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.35.flat_param_0": {"names": ["decoder.layers.35.self_attn.qkv_proj.weight", "decoder.layers.35.self_attn.qkv_proj.bias", "decoder.layers.35.self_attn.out_proj.weight", "decoder.layers.35.self_attn.out_proj.bias", "decoder.layers.35.self_attn_layer_norm.weight", "decoder.layers.35.self_attn_layer_norm.bias", "decoder.layers.35.fc1.weight", "decoder.layers.35.fc1.bias", "decoder.layers.35.fc2.weight", "decoder.layers.35.fc2.bias", "decoder.layers.35.final_layer_norm.weight", "decoder.layers.35.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.36.flat_param_0": {"names": ["decoder.layers.36.self_attn.qkv_proj.weight", "decoder.layers.36.self_attn.qkv_proj.bias", "decoder.layers.36.self_attn.out_proj.weight", "decoder.layers.36.self_attn.out_proj.bias", "decoder.layers.36.self_attn_layer_norm.weight", "decoder.layers.36.self_attn_layer_norm.bias", "decoder.layers.36.fc1.weight", "decoder.layers.36.fc1.bias", "decoder.layers.36.fc2.weight", "decoder.layers.36.fc2.bias", "decoder.layers.36.final_layer_norm.weight", "decoder.layers.36.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.37.flat_param_0": {"names": ["decoder.layers.37.self_attn.qkv_proj.weight", "decoder.layers.37.self_attn.qkv_proj.bias", "decoder.layers.37.self_attn.out_proj.weight", "decoder.layers.37.self_attn.out_proj.bias", "decoder.layers.37.self_attn_layer_norm.weight", "decoder.layers.37.self_attn_layer_norm.bias", "decoder.layers.37.fc1.weight", "decoder.layers.37.fc1.bias", "decoder.layers.37.fc2.weight", "decoder.layers.37.fc2.bias", "decoder.layers.37.final_layer_norm.weight", "decoder.layers.37.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.38.flat_param_0": {"names": ["decoder.layers.38.self_attn.qkv_proj.weight", "decoder.layers.38.self_attn.qkv_proj.bias", "decoder.layers.38.self_attn.out_proj.weight", "decoder.layers.38.self_attn.out_proj.bias", "decoder.layers.38.self_attn_layer_norm.weight", "decoder.layers.38.self_attn_layer_norm.bias", "decoder.layers.38.fc1.weight", "decoder.layers.38.fc1.bias", "decoder.layers.38.fc2.weight", "decoder.layers.38.fc2.bias", "decoder.layers.38.final_layer_norm.weight", "decoder.layers.38.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.39.flat_param_0": {"names": ["decoder.layers.39.self_attn.qkv_proj.weight", "decoder.layers.39.self_attn.qkv_proj.bias", "decoder.layers.39.self_attn.out_proj.weight", "decoder.layers.39.self_attn.out_proj.bias", "decoder.layers.39.self_attn_layer_norm.weight", "decoder.layers.39.self_attn_layer_norm.bias", "decoder.layers.39.fc1.weight", "decoder.layers.39.fc1.bias", "decoder.layers.39.fc2.weight", "decoder.layers.39.fc2.bias", "decoder.layers.39.final_layer_norm.weight", "decoder.layers.39.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.40.flat_param_0": {"names": ["decoder.layers.40.self_attn.qkv_proj.weight", "decoder.layers.40.self_attn.qkv_proj.bias", "decoder.layers.40.self_attn.out_proj.weight", "decoder.layers.40.self_attn.out_proj.bias", "decoder.layers.40.self_attn_layer_norm.weight", "decoder.layers.40.self_attn_layer_norm.bias", "decoder.layers.40.fc1.weight", "decoder.layers.40.fc1.bias", "decoder.layers.40.fc2.weight", "decoder.layers.40.fc2.bias", "decoder.layers.40.final_layer_norm.weight", "decoder.layers.40.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.41.flat_param_0": {"names": ["decoder.layers.41.self_attn.qkv_proj.weight", "decoder.layers.41.self_attn.qkv_proj.bias", "decoder.layers.41.self_attn.out_proj.weight", "decoder.layers.41.self_attn.out_proj.bias", "decoder.layers.41.self_attn_layer_norm.weight", "decoder.layers.41.self_attn_layer_norm.bias", "decoder.layers.41.fc1.weight", "decoder.layers.41.fc1.bias", "decoder.layers.41.fc2.weight", "decoder.layers.41.fc2.bias", "decoder.layers.41.final_layer_norm.weight", "decoder.layers.41.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.42.flat_param_0": {"names": ["decoder.layers.42.self_attn.qkv_proj.weight", "decoder.layers.42.self_attn.qkv_proj.bias", "decoder.layers.42.self_attn.out_proj.weight", "decoder.layers.42.self_attn.out_proj.bias", "decoder.layers.42.self_attn_layer_norm.weight", "decoder.layers.42.self_attn_layer_norm.bias", "decoder.layers.42.fc1.weight", "decoder.layers.42.fc1.bias", "decoder.layers.42.fc2.weight", "decoder.layers.42.fc2.bias", "decoder.layers.42.final_layer_norm.weight", "decoder.layers.42.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.43.flat_param_0": {"names": ["decoder.layers.43.self_attn.qkv_proj.weight", "decoder.layers.43.self_attn.qkv_proj.bias", "decoder.layers.43.self_attn.out_proj.weight", "decoder.layers.43.self_attn.out_proj.bias", "decoder.layers.43.self_attn_layer_norm.weight", "decoder.layers.43.self_attn_layer_norm.bias", "decoder.layers.43.fc1.weight", "decoder.layers.43.fc1.bias", "decoder.layers.43.fc2.weight", "decoder.layers.43.fc2.bias", "decoder.layers.43.final_layer_norm.weight", "decoder.layers.43.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.44.flat_param_0": {"names": ["decoder.layers.44.self_attn.qkv_proj.weight", "decoder.layers.44.self_attn.qkv_proj.bias", "decoder.layers.44.self_attn.out_proj.weight", "decoder.layers.44.self_attn.out_proj.bias", "decoder.layers.44.self_attn_layer_norm.weight", "decoder.layers.44.self_attn_layer_norm.bias", "decoder.layers.44.fc1.weight", "decoder.layers.44.fc1.bias", "decoder.layers.44.fc2.weight", "decoder.layers.44.fc2.bias", "decoder.layers.44.final_layer_norm.weight", "decoder.layers.44.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.45.flat_param_0": {"names": ["decoder.layers.45.self_attn.qkv_proj.weight", "decoder.layers.45.self_attn.qkv_proj.bias", "decoder.layers.45.self_attn.out_proj.weight", "decoder.layers.45.self_attn.out_proj.bias", "decoder.layers.45.self_attn_layer_norm.weight", "decoder.layers.45.self_attn_layer_norm.bias", "decoder.layers.45.fc1.weight", "decoder.layers.45.fc1.bias", "decoder.layers.45.fc2.weight", "decoder.layers.45.fc2.bias", "decoder.layers.45.final_layer_norm.weight", "decoder.layers.45.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.46.flat_param_0": {"names": ["decoder.layers.46.self_attn.qkv_proj.weight", "decoder.layers.46.self_attn.qkv_proj.bias", "decoder.layers.46.self_attn.out_proj.weight", "decoder.layers.46.self_attn.out_proj.bias", "decoder.layers.46.self_attn_layer_norm.weight", "decoder.layers.46.self_attn_layer_norm.bias", "decoder.layers.46.fc1.weight", "decoder.layers.46.fc1.bias", "decoder.layers.46.fc2.weight", "decoder.layers.46.fc2.bias", "decoder.layers.46.final_layer_norm.weight", "decoder.layers.46.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.47.flat_param_0": {"names": ["decoder.layers.47.self_attn.qkv_proj.weight", "decoder.layers.47.self_attn.qkv_proj.bias", "decoder.layers.47.self_attn.out_proj.weight", "decoder.layers.47.self_attn.out_proj.bias", "decoder.layers.47.self_attn_layer_norm.weight", "decoder.layers.47.self_attn_layer_norm.bias", "decoder.layers.47.fc1.weight", "decoder.layers.47.fc1.bias", "decoder.layers.47.fc2.weight", "decoder.layers.47.fc2.bias", "decoder.layers.47.final_layer_norm.weight", "decoder.layers.47.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.48.flat_param_0": {"names": ["decoder.layers.48.self_attn.qkv_proj.weight", "decoder.layers.48.self_attn.qkv_proj.bias", "decoder.layers.48.self_attn.out_proj.weight", "decoder.layers.48.self_attn.out_proj.bias", "decoder.layers.48.self_attn_layer_norm.weight", "decoder.layers.48.self_attn_layer_norm.bias", "decoder.layers.48.fc1.weight", "decoder.layers.48.fc1.bias", "decoder.layers.48.fc2.weight", "decoder.layers.48.fc2.bias", "decoder.layers.48.final_layer_norm.weight", "decoder.layers.48.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.49.flat_param_0": {"names": ["decoder.layers.49.self_attn.qkv_proj.weight", "decoder.layers.49.self_attn.qkv_proj.bias", "decoder.layers.49.self_attn.out_proj.weight", "decoder.layers.49.self_attn.out_proj.bias", "decoder.layers.49.self_attn_layer_norm.weight", "decoder.layers.49.self_attn_layer_norm.bias", "decoder.layers.49.fc1.weight", "decoder.layers.49.fc1.bias", "decoder.layers.49.fc2.weight", "decoder.layers.49.fc2.bias", "decoder.layers.49.final_layer_norm.weight", "decoder.layers.49.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.50.flat_param_0": {"names": ["decoder.layers.50.self_attn.qkv_proj.weight", "decoder.layers.50.self_attn.qkv_proj.bias", "decoder.layers.50.self_attn.out_proj.weight", "decoder.layers.50.self_attn.out_proj.bias", "decoder.layers.50.self_attn_layer_norm.weight", "decoder.layers.50.self_attn_layer_norm.bias", "decoder.layers.50.fc1.weight", "decoder.layers.50.fc1.bias", "decoder.layers.50.fc2.weight", "decoder.layers.50.fc2.bias", "decoder.layers.50.final_layer_norm.weight", "decoder.layers.50.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.51.flat_param_0": {"names": ["decoder.layers.51.self_attn.qkv_proj.weight", "decoder.layers.51.self_attn.qkv_proj.bias", "decoder.layers.51.self_attn.out_proj.weight", "decoder.layers.51.self_attn.out_proj.bias", "decoder.layers.51.self_attn_layer_norm.weight", "decoder.layers.51.self_attn_layer_norm.bias", "decoder.layers.51.fc1.weight", "decoder.layers.51.fc1.bias", "decoder.layers.51.fc2.weight", "decoder.layers.51.fc2.bias", "decoder.layers.51.final_layer_norm.weight", "decoder.layers.51.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.52.flat_param_0": {"names": ["decoder.layers.52.self_attn.qkv_proj.weight", "decoder.layers.52.self_attn.qkv_proj.bias", "decoder.layers.52.self_attn.out_proj.weight", "decoder.layers.52.self_attn.out_proj.bias", "decoder.layers.52.self_attn_layer_norm.weight", "decoder.layers.52.self_attn_layer_norm.bias", "decoder.layers.52.fc1.weight", "decoder.layers.52.fc1.bias", "decoder.layers.52.fc2.weight", "decoder.layers.52.fc2.bias", "decoder.layers.52.final_layer_norm.weight", "decoder.layers.52.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.53.flat_param_0": {"names": ["decoder.layers.53.self_attn.qkv_proj.weight", "decoder.layers.53.self_attn.qkv_proj.bias", "decoder.layers.53.self_attn.out_proj.weight", "decoder.layers.53.self_attn.out_proj.bias", "decoder.layers.53.self_attn_layer_norm.weight", "decoder.layers.53.self_attn_layer_norm.bias", "decoder.layers.53.fc1.weight", "decoder.layers.53.fc1.bias", "decoder.layers.53.fc2.weight", "decoder.layers.53.fc2.bias", "decoder.layers.53.final_layer_norm.weight", "decoder.layers.53.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.54.flat_param_0": {"names": ["decoder.layers.54.self_attn.qkv_proj.weight", "decoder.layers.54.self_attn.qkv_proj.bias", "decoder.layers.54.self_attn.out_proj.weight", "decoder.layers.54.self_attn.out_proj.bias", "decoder.layers.54.self_attn_layer_norm.weight", "decoder.layers.54.self_attn_layer_norm.bias", "decoder.layers.54.fc1.weight", "decoder.layers.54.fc1.bias", "decoder.layers.54.fc2.weight", "decoder.layers.54.fc2.bias", "decoder.layers.54.final_layer_norm.weight", "decoder.layers.54.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.55.flat_param_0": {"names": ["decoder.layers.55.self_attn.qkv_proj.weight", "decoder.layers.55.self_attn.qkv_proj.bias", "decoder.layers.55.self_attn.out_proj.weight", "decoder.layers.55.self_attn.out_proj.bias", "decoder.layers.55.self_attn_layer_norm.weight", "decoder.layers.55.self_attn_layer_norm.bias", "decoder.layers.55.fc1.weight", "decoder.layers.55.fc1.bias", "decoder.layers.55.fc2.weight", "decoder.layers.55.fc2.bias", "decoder.layers.55.final_layer_norm.weight", "decoder.layers.55.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.56.flat_param_0": {"names": ["decoder.layers.56.self_attn.qkv_proj.weight", "decoder.layers.56.self_attn.qkv_proj.bias", "decoder.layers.56.self_attn.out_proj.weight", "decoder.layers.56.self_attn.out_proj.bias", "decoder.layers.56.self_attn_layer_norm.weight", "decoder.layers.56.self_attn_layer_norm.bias", "decoder.layers.56.fc1.weight", "decoder.layers.56.fc1.bias", "decoder.layers.56.fc2.weight", "decoder.layers.56.fc2.bias", "decoder.layers.56.final_layer_norm.weight", "decoder.layers.56.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.57.flat_param_0": {"names": ["decoder.layers.57.self_attn.qkv_proj.weight", "decoder.layers.57.self_attn.qkv_proj.bias", "decoder.layers.57.self_attn.out_proj.weight", "decoder.layers.57.self_attn.out_proj.bias", "decoder.layers.57.self_attn_layer_norm.weight", "decoder.layers.57.self_attn_layer_norm.bias", "decoder.layers.57.fc1.weight", "decoder.layers.57.fc1.bias", "decoder.layers.57.fc2.weight", "decoder.layers.57.fc2.bias", "decoder.layers.57.final_layer_norm.weight", "decoder.layers.57.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.58.flat_param_0": {"names": ["decoder.layers.58.self_attn.qkv_proj.weight", "decoder.layers.58.self_attn.qkv_proj.bias", "decoder.layers.58.self_attn.out_proj.weight", "decoder.layers.58.self_attn.out_proj.bias", "decoder.layers.58.self_attn_layer_norm.weight", "decoder.layers.58.self_attn_layer_norm.bias", "decoder.layers.58.fc1.weight", "decoder.layers.58.fc1.bias", "decoder.layers.58.fc2.weight", "decoder.layers.58.fc2.bias", "decoder.layers.58.final_layer_norm.weight", "decoder.layers.58.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.59.flat_param_0": {"names": ["decoder.layers.59.self_attn.qkv_proj.weight", "decoder.layers.59.self_attn.qkv_proj.bias", "decoder.layers.59.self_attn.out_proj.weight", "decoder.layers.59.self_attn.out_proj.bias", "decoder.layers.59.self_attn_layer_norm.weight", "decoder.layers.59.self_attn_layer_norm.bias", "decoder.layers.59.fc1.weight", "decoder.layers.59.fc1.bias", "decoder.layers.59.fc2.weight", "decoder.layers.59.fc2.bias", "decoder.layers.59.final_layer_norm.weight", "decoder.layers.59.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.60.flat_param_0": {"names": ["decoder.layers.60.self_attn.qkv_proj.weight", "decoder.layers.60.self_attn.qkv_proj.bias", "decoder.layers.60.self_attn.out_proj.weight", "decoder.layers.60.self_attn.out_proj.bias", "decoder.layers.60.self_attn_layer_norm.weight", "decoder.layers.60.self_attn_layer_norm.bias", "decoder.layers.60.fc1.weight", "decoder.layers.60.fc1.bias", "decoder.layers.60.fc2.weight", "decoder.layers.60.fc2.bias", "decoder.layers.60.final_layer_norm.weight", "decoder.layers.60.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.61.flat_param_0": {"names": ["decoder.layers.61.self_attn.qkv_proj.weight", "decoder.layers.61.self_attn.qkv_proj.bias", "decoder.layers.61.self_attn.out_proj.weight", "decoder.layers.61.self_attn.out_proj.bias", "decoder.layers.61.self_attn_layer_norm.weight", "decoder.layers.61.self_attn_layer_norm.bias", "decoder.layers.61.fc1.weight", "decoder.layers.61.fc1.bias", "decoder.layers.61.fc2.weight", "decoder.layers.61.fc2.bias", "decoder.layers.61.final_layer_norm.weight", "decoder.layers.61.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.62.flat_param_0": {"names": ["decoder.layers.62.self_attn.qkv_proj.weight", "decoder.layers.62.self_attn.qkv_proj.bias", "decoder.layers.62.self_attn.out_proj.weight", "decoder.layers.62.self_attn.out_proj.bias", "decoder.layers.62.self_attn_layer_norm.weight", "decoder.layers.62.self_attn_layer_norm.bias", "decoder.layers.62.fc1.weight", "decoder.layers.62.fc1.bias", "decoder.layers.62.fc2.weight", "decoder.layers.62.fc2.bias", "decoder.layers.62.final_layer_norm.weight", "decoder.layers.62.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.63.flat_param_0": {"names": ["decoder.layers.63.self_attn.qkv_proj.weight", "decoder.layers.63.self_attn.qkv_proj.bias", "decoder.layers.63.self_attn.out_proj.weight", "decoder.layers.63.self_attn.out_proj.bias", "decoder.layers.63.self_attn_layer_norm.weight", "decoder.layers.63.self_attn_layer_norm.bias", "decoder.layers.63.fc1.weight", "decoder.layers.63.fc1.bias", "decoder.layers.63.fc2.weight", "decoder.layers.63.fc2.bias", "decoder.layers.63.final_layer_norm.weight", "decoder.layers.63.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.64.flat_param_0": {"names": ["decoder.layers.64.self_attn.qkv_proj.weight", "decoder.layers.64.self_attn.qkv_proj.bias", "decoder.layers.64.self_attn.out_proj.weight", "decoder.layers.64.self_attn.out_proj.bias", "decoder.layers.64.self_attn_layer_norm.weight", "decoder.layers.64.self_attn_layer_norm.bias", "decoder.layers.64.fc1.weight", "decoder.layers.64.fc1.bias", "decoder.layers.64.fc2.weight", "decoder.layers.64.fc2.bias", "decoder.layers.64.final_layer_norm.weight", "decoder.layers.64.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.65.flat_param_0": {"names": ["decoder.layers.65.self_attn.qkv_proj.weight", "decoder.layers.65.self_attn.qkv_proj.bias", "decoder.layers.65.self_attn.out_proj.weight", "decoder.layers.65.self_attn.out_proj.bias", "decoder.layers.65.self_attn_layer_norm.weight", "decoder.layers.65.self_attn_layer_norm.bias", "decoder.layers.65.fc1.weight", "decoder.layers.65.fc1.bias", "decoder.layers.65.fc2.weight", "decoder.layers.65.fc2.bias", "decoder.layers.65.final_layer_norm.weight", "decoder.layers.65.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.66.flat_param_0": {"names": ["decoder.layers.66.self_attn.qkv_proj.weight", "decoder.layers.66.self_attn.qkv_proj.bias", "decoder.layers.66.self_attn.out_proj.weight", "decoder.layers.66.self_attn.out_proj.bias", "decoder.layers.66.self_attn_layer_norm.weight", "decoder.layers.66.self_attn_layer_norm.bias", "decoder.layers.66.fc1.weight", "decoder.layers.66.fc1.bias", "decoder.layers.66.fc2.weight", "decoder.layers.66.fc2.bias", "decoder.layers.66.final_layer_norm.weight", "decoder.layers.66.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.67.flat_param_0": {"names": ["decoder.layers.67.self_attn.qkv_proj.weight", "decoder.layers.67.self_attn.qkv_proj.bias", "decoder.layers.67.self_attn.out_proj.weight", "decoder.layers.67.self_attn.out_proj.bias", "decoder.layers.67.self_attn_layer_norm.weight", "decoder.layers.67.self_attn_layer_norm.bias", "decoder.layers.67.fc1.weight", "decoder.layers.67.fc1.bias", "decoder.layers.67.fc2.weight", "decoder.layers.67.fc2.bias", "decoder.layers.67.final_layer_norm.weight", "decoder.layers.67.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.68.flat_param_0": {"names": ["decoder.layers.68.self_attn.qkv_proj.weight", "decoder.layers.68.self_attn.qkv_proj.bias", "decoder.layers.68.self_attn.out_proj.weight", "decoder.layers.68.self_attn.out_proj.bias", "decoder.layers.68.self_attn_layer_norm.weight", "decoder.layers.68.self_attn_layer_norm.bias", "decoder.layers.68.fc1.weight", "decoder.layers.68.fc1.bias", "decoder.layers.68.fc2.weight", "decoder.layers.68.fc2.bias", "decoder.layers.68.final_layer_norm.weight", "decoder.layers.68.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.69.flat_param_0": {"names": ["decoder.layers.69.self_attn.qkv_proj.weight", "decoder.layers.69.self_attn.qkv_proj.bias", "decoder.layers.69.self_attn.out_proj.weight", "decoder.layers.69.self_attn.out_proj.bias", "decoder.layers.69.self_attn_layer_norm.weight", "decoder.layers.69.self_attn_layer_norm.bias", "decoder.layers.69.fc1.weight", "decoder.layers.69.fc1.bias", "decoder.layers.69.fc2.weight", "decoder.layers.69.fc2.bias", "decoder.layers.69.final_layer_norm.weight", "decoder.layers.69.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.70.flat_param_0": {"names": ["decoder.layers.70.self_attn.qkv_proj.weight", "decoder.layers.70.self_attn.qkv_proj.bias", "decoder.layers.70.self_attn.out_proj.weight", "decoder.layers.70.self_attn.out_proj.bias", "decoder.layers.70.self_attn_layer_norm.weight", "decoder.layers.70.self_attn_layer_norm.bias", "decoder.layers.70.fc1.weight", "decoder.layers.70.fc1.bias", "decoder.layers.70.fc2.weight", "decoder.layers.70.fc2.bias", "decoder.layers.70.final_layer_norm.weight", "decoder.layers.70.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.71.flat_param_0": {"names": ["decoder.layers.71.self_attn.qkv_proj.weight", "decoder.layers.71.self_attn.qkv_proj.bias", "decoder.layers.71.self_attn.out_proj.weight", "decoder.layers.71.self_attn.out_proj.bias", "decoder.layers.71.self_attn_layer_norm.weight", "decoder.layers.71.self_attn_layer_norm.bias", "decoder.layers.71.fc1.weight", "decoder.layers.71.fc1.bias", "decoder.layers.71.fc2.weight", "decoder.layers.71.fc2.bias", "decoder.layers.71.final_layer_norm.weight", "decoder.layers.71.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.72.flat_param_0": {"names": ["decoder.layers.72.self_attn.qkv_proj.weight", "decoder.layers.72.self_attn.qkv_proj.bias", "decoder.layers.72.self_attn.out_proj.weight", "decoder.layers.72.self_attn.out_proj.bias", "decoder.layers.72.self_attn_layer_norm.weight", "decoder.layers.72.self_attn_layer_norm.bias", "decoder.layers.72.fc1.weight", "decoder.layers.72.fc1.bias", "decoder.layers.72.fc2.weight", "decoder.layers.72.fc2.bias", "decoder.layers.72.final_layer_norm.weight", "decoder.layers.72.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.73.flat_param_0": {"names": ["decoder.layers.73.self_attn.qkv_proj.weight", "decoder.layers.73.self_attn.qkv_proj.bias", "decoder.layers.73.self_attn.out_proj.weight", "decoder.layers.73.self_attn.out_proj.bias", "decoder.layers.73.self_attn_layer_norm.weight", "decoder.layers.73.self_attn_layer_norm.bias", "decoder.layers.73.fc1.weight", "decoder.layers.73.fc1.bias", "decoder.layers.73.fc2.weight", "decoder.layers.73.fc2.bias", "decoder.layers.73.final_layer_norm.weight", "decoder.layers.73.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.74.flat_param_0": {"names": ["decoder.layers.74.self_attn.qkv_proj.weight", "decoder.layers.74.self_attn.qkv_proj.bias", "decoder.layers.74.self_attn.out_proj.weight", "decoder.layers.74.self_attn.out_proj.bias", "decoder.layers.74.self_attn_layer_norm.weight", "decoder.layers.74.self_attn_layer_norm.bias", "decoder.layers.74.fc1.weight", "decoder.layers.74.fc1.bias", "decoder.layers.74.fc2.weight", "decoder.layers.74.fc2.bias", "decoder.layers.74.final_layer_norm.weight", "decoder.layers.74.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.75.flat_param_0": {"names": ["decoder.layers.75.self_attn.qkv_proj.weight", "decoder.layers.75.self_attn.qkv_proj.bias", "decoder.layers.75.self_attn.out_proj.weight", "decoder.layers.75.self_attn.out_proj.bias", "decoder.layers.75.self_attn_layer_norm.weight", "decoder.layers.75.self_attn_layer_norm.bias", "decoder.layers.75.fc1.weight", "decoder.layers.75.fc1.bias", "decoder.layers.75.fc2.weight", "decoder.layers.75.fc2.bias", "decoder.layers.75.final_layer_norm.weight", "decoder.layers.75.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.76.flat_param_0": {"names": ["decoder.layers.76.self_attn.qkv_proj.weight", "decoder.layers.76.self_attn.qkv_proj.bias", "decoder.layers.76.self_attn.out_proj.weight", "decoder.layers.76.self_attn.out_proj.bias", "decoder.layers.76.self_attn_layer_norm.weight", "decoder.layers.76.self_attn_layer_norm.bias", "decoder.layers.76.fc1.weight", "decoder.layers.76.fc1.bias", "decoder.layers.76.fc2.weight", "decoder.layers.76.fc2.bias", "decoder.layers.76.final_layer_norm.weight", "decoder.layers.76.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.77.flat_param_0": {"names": ["decoder.layers.77.self_attn.qkv_proj.weight", "decoder.layers.77.self_attn.qkv_proj.bias", "decoder.layers.77.self_attn.out_proj.weight", "decoder.layers.77.self_attn.out_proj.bias", "decoder.layers.77.self_attn_layer_norm.weight", "decoder.layers.77.self_attn_layer_norm.bias", "decoder.layers.77.fc1.weight", "decoder.layers.77.fc1.bias", "decoder.layers.77.fc2.weight", "decoder.layers.77.fc2.bias", "decoder.layers.77.final_layer_norm.weight", "decoder.layers.77.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.78.flat_param_0": {"names": ["decoder.layers.78.self_attn.qkv_proj.weight", "decoder.layers.78.self_attn.qkv_proj.bias", "decoder.layers.78.self_attn.out_proj.weight", "decoder.layers.78.self_attn.out_proj.bias", "decoder.layers.78.self_attn_layer_norm.weight", "decoder.layers.78.self_attn_layer_norm.bias", "decoder.layers.78.fc1.weight", "decoder.layers.78.fc1.bias", "decoder.layers.78.fc2.weight", "decoder.layers.78.fc2.bias", "decoder.layers.78.final_layer_norm.weight", "decoder.layers.78.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.79.flat_param_0": {"names": ["decoder.layers.79.self_attn.qkv_proj.weight", "decoder.layers.79.self_attn.qkv_proj.bias", "decoder.layers.79.self_attn.out_proj.weight", "decoder.layers.79.self_attn.out_proj.bias", "decoder.layers.79.self_attn_layer_norm.weight", "decoder.layers.79.self_attn_layer_norm.bias", "decoder.layers.79.fc1.weight", "decoder.layers.79.fc1.bias", "decoder.layers.79.fc2.weight", "decoder.layers.79.fc2.bias", "decoder.layers.79.final_layer_norm.weight", "decoder.layers.79.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.80.flat_param_0": {"names": ["decoder.layers.80.self_attn.qkv_proj.weight", "decoder.layers.80.self_attn.qkv_proj.bias", "decoder.layers.80.self_attn.out_proj.weight", "decoder.layers.80.self_attn.out_proj.bias", "decoder.layers.80.self_attn_layer_norm.weight", "decoder.layers.80.self_attn_layer_norm.bias", "decoder.layers.80.fc1.weight", "decoder.layers.80.fc1.bias", "decoder.layers.80.fc2.weight", "decoder.layers.80.fc2.bias", "decoder.layers.80.final_layer_norm.weight", "decoder.layers.80.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.81.flat_param_0": {"names": ["decoder.layers.81.self_attn.qkv_proj.weight", "decoder.layers.81.self_attn.qkv_proj.bias", "decoder.layers.81.self_attn.out_proj.weight", "decoder.layers.81.self_attn.out_proj.bias", "decoder.layers.81.self_attn_layer_norm.weight", "decoder.layers.81.self_attn_layer_norm.bias", "decoder.layers.81.fc1.weight", "decoder.layers.81.fc1.bias", "decoder.layers.81.fc2.weight", "decoder.layers.81.fc2.bias", "decoder.layers.81.final_layer_norm.weight", "decoder.layers.81.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.82.flat_param_0": {"names": ["decoder.layers.82.self_attn.qkv_proj.weight", "decoder.layers.82.self_attn.qkv_proj.bias", "decoder.layers.82.self_attn.out_proj.weight", "decoder.layers.82.self_attn.out_proj.bias", "decoder.layers.82.self_attn_layer_norm.weight", "decoder.layers.82.self_attn_layer_norm.bias", "decoder.layers.82.fc1.weight", "decoder.layers.82.fc1.bias", "decoder.layers.82.fc2.weight", "decoder.layers.82.fc2.bias", "decoder.layers.82.final_layer_norm.weight", "decoder.layers.82.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.83.flat_param_0": {"names": ["decoder.layers.83.self_attn.qkv_proj.weight", "decoder.layers.83.self_attn.qkv_proj.bias", "decoder.layers.83.self_attn.out_proj.weight", "decoder.layers.83.self_attn.out_proj.bias", "decoder.layers.83.self_attn_layer_norm.weight", "decoder.layers.83.self_attn_layer_norm.bias", "decoder.layers.83.fc1.weight", "decoder.layers.83.fc1.bias", "decoder.layers.83.fc2.weight", "decoder.layers.83.fc2.bias", "decoder.layers.83.final_layer_norm.weight", "decoder.layers.83.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.84.flat_param_0": {"names": ["decoder.layers.84.self_attn.qkv_proj.weight", "decoder.layers.84.self_attn.qkv_proj.bias", "decoder.layers.84.self_attn.out_proj.weight", "decoder.layers.84.self_attn.out_proj.bias", "decoder.layers.84.self_attn_layer_norm.weight", "decoder.layers.84.self_attn_layer_norm.bias", "decoder.layers.84.fc1.weight", "decoder.layers.84.fc1.bias", "decoder.layers.84.fc2.weight", "decoder.layers.84.fc2.bias", "decoder.layers.84.final_layer_norm.weight", "decoder.layers.84.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.85.flat_param_0": {"names": ["decoder.layers.85.self_attn.qkv_proj.weight", "decoder.layers.85.self_attn.qkv_proj.bias", "decoder.layers.85.self_attn.out_proj.weight", "decoder.layers.85.self_attn.out_proj.bias", "decoder.layers.85.self_attn_layer_norm.weight", "decoder.layers.85.self_attn_layer_norm.bias", "decoder.layers.85.fc1.weight", "decoder.layers.85.fc1.bias", "decoder.layers.85.fc2.weight", "decoder.layers.85.fc2.bias", "decoder.layers.85.final_layer_norm.weight", "decoder.layers.85.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.86.flat_param_0": {"names": ["decoder.layers.86.self_attn.qkv_proj.weight", "decoder.layers.86.self_attn.qkv_proj.bias", "decoder.layers.86.self_attn.out_proj.weight", "decoder.layers.86.self_attn.out_proj.bias", "decoder.layers.86.self_attn_layer_norm.weight", "decoder.layers.86.self_attn_layer_norm.bias", "decoder.layers.86.fc1.weight", "decoder.layers.86.fc1.bias", "decoder.layers.86.fc2.weight", "decoder.layers.86.fc2.bias", "decoder.layers.86.final_layer_norm.weight", "decoder.layers.86.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.87.flat_param_0": {"names": ["decoder.layers.87.self_attn.qkv_proj.weight", "decoder.layers.87.self_attn.qkv_proj.bias", "decoder.layers.87.self_attn.out_proj.weight", "decoder.layers.87.self_attn.out_proj.bias", "decoder.layers.87.self_attn_layer_norm.weight", "decoder.layers.87.self_attn_layer_norm.bias", "decoder.layers.87.fc1.weight", "decoder.layers.87.fc1.bias", "decoder.layers.87.fc2.weight", "decoder.layers.87.fc2.bias", "decoder.layers.87.final_layer_norm.weight", "decoder.layers.87.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.88.flat_param_0": {"names": ["decoder.layers.88.self_attn.qkv_proj.weight", "decoder.layers.88.self_attn.qkv_proj.bias", "decoder.layers.88.self_attn.out_proj.weight", "decoder.layers.88.self_attn.out_proj.bias", "decoder.layers.88.self_attn_layer_norm.weight", "decoder.layers.88.self_attn_layer_norm.bias", "decoder.layers.88.fc1.weight", "decoder.layers.88.fc1.bias", "decoder.layers.88.fc2.weight", "decoder.layers.88.fc2.bias", "decoder.layers.88.final_layer_norm.weight", "decoder.layers.88.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.89.flat_param_0": {"names": ["decoder.layers.89.self_attn.qkv_proj.weight", "decoder.layers.89.self_attn.qkv_proj.bias", "decoder.layers.89.self_attn.out_proj.weight", "decoder.layers.89.self_attn.out_proj.bias", "decoder.layers.89.self_attn_layer_norm.weight", "decoder.layers.89.self_attn_layer_norm.bias", "decoder.layers.89.fc1.weight", "decoder.layers.89.fc1.bias", "decoder.layers.89.fc2.weight", "decoder.layers.89.fc2.bias", "decoder.layers.89.final_layer_norm.weight", "decoder.layers.89.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.90.flat_param_0": {"names": ["decoder.layers.90.self_attn.qkv_proj.weight", "decoder.layers.90.self_attn.qkv_proj.bias", "decoder.layers.90.self_attn.out_proj.weight", "decoder.layers.90.self_attn.out_proj.bias", "decoder.layers.90.self_attn_layer_norm.weight", "decoder.layers.90.self_attn_layer_norm.bias", "decoder.layers.90.fc1.weight", "decoder.layers.90.fc1.bias", "decoder.layers.90.fc2.weight", "decoder.layers.90.fc2.bias", "decoder.layers.90.final_layer_norm.weight", "decoder.layers.90.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.91.flat_param_0": {"names": ["decoder.layers.91.self_attn.qkv_proj.weight", "decoder.layers.91.self_attn.qkv_proj.bias", "decoder.layers.91.self_attn.out_proj.weight", "decoder.layers.91.self_attn.out_proj.bias", "decoder.layers.91.self_attn_layer_norm.weight", "decoder.layers.91.self_attn_layer_norm.bias", "decoder.layers.91.fc1.weight", "decoder.layers.91.fc1.bias", "decoder.layers.91.fc2.weight", "decoder.layers.91.fc2.bias", "decoder.layers.91.final_layer_norm.weight", "decoder.layers.91.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.92.flat_param_0": {"names": ["decoder.layers.92.self_attn.qkv_proj.weight", "decoder.layers.92.self_attn.qkv_proj.bias", "decoder.layers.92.self_attn.out_proj.weight", "decoder.layers.92.self_attn.out_proj.bias", "decoder.layers.92.self_attn_layer_norm.weight", "decoder.layers.92.self_attn_layer_norm.bias", "decoder.layers.92.fc1.weight", "decoder.layers.92.fc1.bias", "decoder.layers.92.fc2.weight", "decoder.layers.92.fc2.bias", "decoder.layers.92.final_layer_norm.weight", "decoder.layers.92.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.93.flat_param_0": {"names": ["decoder.layers.93.self_attn.qkv_proj.weight", "decoder.layers.93.self_attn.qkv_proj.bias", "decoder.layers.93.self_attn.out_proj.weight", "decoder.layers.93.self_attn.out_proj.bias", "decoder.layers.93.self_attn_layer_norm.weight", "decoder.layers.93.self_attn_layer_norm.bias", "decoder.layers.93.fc1.weight", "decoder.layers.93.fc1.bias", "decoder.layers.93.fc2.weight", "decoder.layers.93.fc2.bias", "decoder.layers.93.final_layer_norm.weight", "decoder.layers.93.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.94.flat_param_0": {"names": ["decoder.layers.94.self_attn.qkv_proj.weight", "decoder.layers.94.self_attn.qkv_proj.bias", "decoder.layers.94.self_attn.out_proj.weight", "decoder.layers.94.self_attn.out_proj.bias", "decoder.layers.94.self_attn_layer_norm.weight", "decoder.layers.94.self_attn_layer_norm.bias", "decoder.layers.94.fc1.weight", "decoder.layers.94.fc1.bias", "decoder.layers.94.fc2.weight", "decoder.layers.94.fc2.bias", "decoder.layers.94.final_layer_norm.weight", "decoder.layers.94.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.95.flat_param_0": {"names": ["decoder.layers.95.self_attn.qkv_proj.weight", "decoder.layers.95.self_attn.qkv_proj.bias", "decoder.layers.95.self_attn.out_proj.weight", "decoder.layers.95.self_attn.out_proj.bias", "decoder.layers.95.self_attn_layer_norm.weight", "decoder.layers.95.self_attn_layer_norm.bias", "decoder.layers.95.fc1.weight", "decoder.layers.95.fc1.bias", "decoder.layers.95.fc2.weight", "decoder.layers.95.fc2.bias", "decoder.layers.95.final_layer_norm.weight", "decoder.layers.95.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}} \ No newline at end of file +{ + "flat_param_0": { + "names": [ + "decoder.embed_tokens.weight", + "decoder.embed_positions.weight", + "decoder.layer_norm.weight", + "decoder.layer_norm.bias" + ], + "shapes": [ + [ + 6284, + 12288 + ], + [ + 2050, + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 77217792, + 25190400, + 12288, + 12288 + ] + }, + "decoder.layers.0.flat_param_0": { + "names": [ + "decoder.layers.0.self_attn.qkv_proj.weight", + "decoder.layers.0.self_attn.qkv_proj.bias", + "decoder.layers.0.self_attn.out_proj.weight", + "decoder.layers.0.self_attn.out_proj.bias", + "decoder.layers.0.self_attn_layer_norm.weight", + "decoder.layers.0.self_attn_layer_norm.bias", + "decoder.layers.0.fc1.weight", + "decoder.layers.0.fc1.bias", + "decoder.layers.0.fc2.weight", + "decoder.layers.0.fc2.bias", + "decoder.layers.0.final_layer_norm.weight", + "decoder.layers.0.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.1.flat_param_0": { + "names": [ + "decoder.layers.1.self_attn.qkv_proj.weight", + "decoder.layers.1.self_attn.qkv_proj.bias", + "decoder.layers.1.self_attn.out_proj.weight", + "decoder.layers.1.self_attn.out_proj.bias", + "decoder.layers.1.self_attn_layer_norm.weight", + "decoder.layers.1.self_attn_layer_norm.bias", + "decoder.layers.1.fc1.weight", + "decoder.layers.1.fc1.bias", + "decoder.layers.1.fc2.weight", + "decoder.layers.1.fc2.bias", + "decoder.layers.1.final_layer_norm.weight", + "decoder.layers.1.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.2.flat_param_0": { + "names": [ + "decoder.layers.2.self_attn.qkv_proj.weight", + "decoder.layers.2.self_attn.qkv_proj.bias", + "decoder.layers.2.self_attn.out_proj.weight", + "decoder.layers.2.self_attn.out_proj.bias", + "decoder.layers.2.self_attn_layer_norm.weight", + "decoder.layers.2.self_attn_layer_norm.bias", + "decoder.layers.2.fc1.weight", + "decoder.layers.2.fc1.bias", + "decoder.layers.2.fc2.weight", + "decoder.layers.2.fc2.bias", + "decoder.layers.2.final_layer_norm.weight", + "decoder.layers.2.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.3.flat_param_0": { + "names": [ + "decoder.layers.3.self_attn.qkv_proj.weight", + "decoder.layers.3.self_attn.qkv_proj.bias", + "decoder.layers.3.self_attn.out_proj.weight", + "decoder.layers.3.self_attn.out_proj.bias", + "decoder.layers.3.self_attn_layer_norm.weight", + "decoder.layers.3.self_attn_layer_norm.bias", + "decoder.layers.3.fc1.weight", + "decoder.layers.3.fc1.bias", + "decoder.layers.3.fc2.weight", + "decoder.layers.3.fc2.bias", + "decoder.layers.3.final_layer_norm.weight", + "decoder.layers.3.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.4.flat_param_0": { + "names": [ + "decoder.layers.4.self_attn.qkv_proj.weight", + "decoder.layers.4.self_attn.qkv_proj.bias", + "decoder.layers.4.self_attn.out_proj.weight", + "decoder.layers.4.self_attn.out_proj.bias", + "decoder.layers.4.self_attn_layer_norm.weight", + "decoder.layers.4.self_attn_layer_norm.bias", + "decoder.layers.4.fc1.weight", + "decoder.layers.4.fc1.bias", + "decoder.layers.4.fc2.weight", + "decoder.layers.4.fc2.bias", + "decoder.layers.4.final_layer_norm.weight", + "decoder.layers.4.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.5.flat_param_0": { + "names": [ + "decoder.layers.5.self_attn.qkv_proj.weight", + "decoder.layers.5.self_attn.qkv_proj.bias", + "decoder.layers.5.self_attn.out_proj.weight", + "decoder.layers.5.self_attn.out_proj.bias", + "decoder.layers.5.self_attn_layer_norm.weight", + "decoder.layers.5.self_attn_layer_norm.bias", + "decoder.layers.5.fc1.weight", + "decoder.layers.5.fc1.bias", + "decoder.layers.5.fc2.weight", + "decoder.layers.5.fc2.bias", + "decoder.layers.5.final_layer_norm.weight", + "decoder.layers.5.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.6.flat_param_0": { + "names": [ + "decoder.layers.6.self_attn.qkv_proj.weight", + "decoder.layers.6.self_attn.qkv_proj.bias", + "decoder.layers.6.self_attn.out_proj.weight", + "decoder.layers.6.self_attn.out_proj.bias", + "decoder.layers.6.self_attn_layer_norm.weight", + "decoder.layers.6.self_attn_layer_norm.bias", + "decoder.layers.6.fc1.weight", + "decoder.layers.6.fc1.bias", + "decoder.layers.6.fc2.weight", + "decoder.layers.6.fc2.bias", + "decoder.layers.6.final_layer_norm.weight", + "decoder.layers.6.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.7.flat_param_0": { + "names": [ + "decoder.layers.7.self_attn.qkv_proj.weight", + "decoder.layers.7.self_attn.qkv_proj.bias", + "decoder.layers.7.self_attn.out_proj.weight", + "decoder.layers.7.self_attn.out_proj.bias", + "decoder.layers.7.self_attn_layer_norm.weight", + "decoder.layers.7.self_attn_layer_norm.bias", + "decoder.layers.7.fc1.weight", + "decoder.layers.7.fc1.bias", + "decoder.layers.7.fc2.weight", + "decoder.layers.7.fc2.bias", + "decoder.layers.7.final_layer_norm.weight", + "decoder.layers.7.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.8.flat_param_0": { + "names": [ + "decoder.layers.8.self_attn.qkv_proj.weight", + "decoder.layers.8.self_attn.qkv_proj.bias", + "decoder.layers.8.self_attn.out_proj.weight", + "decoder.layers.8.self_attn.out_proj.bias", + "decoder.layers.8.self_attn_layer_norm.weight", + "decoder.layers.8.self_attn_layer_norm.bias", + "decoder.layers.8.fc1.weight", + "decoder.layers.8.fc1.bias", + "decoder.layers.8.fc2.weight", + "decoder.layers.8.fc2.bias", + "decoder.layers.8.final_layer_norm.weight", + "decoder.layers.8.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.9.flat_param_0": { + "names": [ + "decoder.layers.9.self_attn.qkv_proj.weight", + "decoder.layers.9.self_attn.qkv_proj.bias", + "decoder.layers.9.self_attn.out_proj.weight", + "decoder.layers.9.self_attn.out_proj.bias", + "decoder.layers.9.self_attn_layer_norm.weight", + "decoder.layers.9.self_attn_layer_norm.bias", + "decoder.layers.9.fc1.weight", + "decoder.layers.9.fc1.bias", + "decoder.layers.9.fc2.weight", + "decoder.layers.9.fc2.bias", + "decoder.layers.9.final_layer_norm.weight", + "decoder.layers.9.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.10.flat_param_0": { + "names": [ + "decoder.layers.10.self_attn.qkv_proj.weight", + "decoder.layers.10.self_attn.qkv_proj.bias", + "decoder.layers.10.self_attn.out_proj.weight", + "decoder.layers.10.self_attn.out_proj.bias", + "decoder.layers.10.self_attn_layer_norm.weight", + "decoder.layers.10.self_attn_layer_norm.bias", + "decoder.layers.10.fc1.weight", + "decoder.layers.10.fc1.bias", + "decoder.layers.10.fc2.weight", + "decoder.layers.10.fc2.bias", + "decoder.layers.10.final_layer_norm.weight", + "decoder.layers.10.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.11.flat_param_0": { + "names": [ + "decoder.layers.11.self_attn.qkv_proj.weight", + "decoder.layers.11.self_attn.qkv_proj.bias", + "decoder.layers.11.self_attn.out_proj.weight", + "decoder.layers.11.self_attn.out_proj.bias", + "decoder.layers.11.self_attn_layer_norm.weight", + "decoder.layers.11.self_attn_layer_norm.bias", + "decoder.layers.11.fc1.weight", + "decoder.layers.11.fc1.bias", + "decoder.layers.11.fc2.weight", + "decoder.layers.11.fc2.bias", + "decoder.layers.11.final_layer_norm.weight", + "decoder.layers.11.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.12.flat_param_0": { + "names": [ + "decoder.layers.12.self_attn.qkv_proj.weight", + "decoder.layers.12.self_attn.qkv_proj.bias", + "decoder.layers.12.self_attn.out_proj.weight", + "decoder.layers.12.self_attn.out_proj.bias", + "decoder.layers.12.self_attn_layer_norm.weight", + "decoder.layers.12.self_attn_layer_norm.bias", + "decoder.layers.12.fc1.weight", + "decoder.layers.12.fc1.bias", + "decoder.layers.12.fc2.weight", + "decoder.layers.12.fc2.bias", + "decoder.layers.12.final_layer_norm.weight", + "decoder.layers.12.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.13.flat_param_0": { + "names": [ + "decoder.layers.13.self_attn.qkv_proj.weight", + "decoder.layers.13.self_attn.qkv_proj.bias", + "decoder.layers.13.self_attn.out_proj.weight", + "decoder.layers.13.self_attn.out_proj.bias", + "decoder.layers.13.self_attn_layer_norm.weight", + "decoder.layers.13.self_attn_layer_norm.bias", + "decoder.layers.13.fc1.weight", + "decoder.layers.13.fc1.bias", + "decoder.layers.13.fc2.weight", + "decoder.layers.13.fc2.bias", + "decoder.layers.13.final_layer_norm.weight", + "decoder.layers.13.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.14.flat_param_0": { + "names": [ + "decoder.layers.14.self_attn.qkv_proj.weight", + "decoder.layers.14.self_attn.qkv_proj.bias", + "decoder.layers.14.self_attn.out_proj.weight", + "decoder.layers.14.self_attn.out_proj.bias", + "decoder.layers.14.self_attn_layer_norm.weight", + "decoder.layers.14.self_attn_layer_norm.bias", + "decoder.layers.14.fc1.weight", + "decoder.layers.14.fc1.bias", + "decoder.layers.14.fc2.weight", + "decoder.layers.14.fc2.bias", + "decoder.layers.14.final_layer_norm.weight", + "decoder.layers.14.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.15.flat_param_0": { + "names": [ + "decoder.layers.15.self_attn.qkv_proj.weight", + "decoder.layers.15.self_attn.qkv_proj.bias", + "decoder.layers.15.self_attn.out_proj.weight", + "decoder.layers.15.self_attn.out_proj.bias", + "decoder.layers.15.self_attn_layer_norm.weight", + "decoder.layers.15.self_attn_layer_norm.bias", + "decoder.layers.15.fc1.weight", + "decoder.layers.15.fc1.bias", + "decoder.layers.15.fc2.weight", + "decoder.layers.15.fc2.bias", + "decoder.layers.15.final_layer_norm.weight", + "decoder.layers.15.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.16.flat_param_0": { + "names": [ + "decoder.layers.16.self_attn.qkv_proj.weight", + "decoder.layers.16.self_attn.qkv_proj.bias", + "decoder.layers.16.self_attn.out_proj.weight", + "decoder.layers.16.self_attn.out_proj.bias", + "decoder.layers.16.self_attn_layer_norm.weight", + "decoder.layers.16.self_attn_layer_norm.bias", + "decoder.layers.16.fc1.weight", + "decoder.layers.16.fc1.bias", + "decoder.layers.16.fc2.weight", + "decoder.layers.16.fc2.bias", + "decoder.layers.16.final_layer_norm.weight", + "decoder.layers.16.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.17.flat_param_0": { + "names": [ + "decoder.layers.17.self_attn.qkv_proj.weight", + "decoder.layers.17.self_attn.qkv_proj.bias", + "decoder.layers.17.self_attn.out_proj.weight", + "decoder.layers.17.self_attn.out_proj.bias", + "decoder.layers.17.self_attn_layer_norm.weight", + "decoder.layers.17.self_attn_layer_norm.bias", + "decoder.layers.17.fc1.weight", + "decoder.layers.17.fc1.bias", + "decoder.layers.17.fc2.weight", + "decoder.layers.17.fc2.bias", + "decoder.layers.17.final_layer_norm.weight", + "decoder.layers.17.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.18.flat_param_0": { + "names": [ + "decoder.layers.18.self_attn.qkv_proj.weight", + "decoder.layers.18.self_attn.qkv_proj.bias", + "decoder.layers.18.self_attn.out_proj.weight", + "decoder.layers.18.self_attn.out_proj.bias", + "decoder.layers.18.self_attn_layer_norm.weight", + "decoder.layers.18.self_attn_layer_norm.bias", + "decoder.layers.18.fc1.weight", + "decoder.layers.18.fc1.bias", + "decoder.layers.18.fc2.weight", + "decoder.layers.18.fc2.bias", + "decoder.layers.18.final_layer_norm.weight", + "decoder.layers.18.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.19.flat_param_0": { + "names": [ + "decoder.layers.19.self_attn.qkv_proj.weight", + "decoder.layers.19.self_attn.qkv_proj.bias", + "decoder.layers.19.self_attn.out_proj.weight", + "decoder.layers.19.self_attn.out_proj.bias", + "decoder.layers.19.self_attn_layer_norm.weight", + "decoder.layers.19.self_attn_layer_norm.bias", + "decoder.layers.19.fc1.weight", + "decoder.layers.19.fc1.bias", + "decoder.layers.19.fc2.weight", + "decoder.layers.19.fc2.bias", + "decoder.layers.19.final_layer_norm.weight", + "decoder.layers.19.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.20.flat_param_0": { + "names": [ + "decoder.layers.20.self_attn.qkv_proj.weight", + "decoder.layers.20.self_attn.qkv_proj.bias", + "decoder.layers.20.self_attn.out_proj.weight", + "decoder.layers.20.self_attn.out_proj.bias", + "decoder.layers.20.self_attn_layer_norm.weight", + "decoder.layers.20.self_attn_layer_norm.bias", + "decoder.layers.20.fc1.weight", + "decoder.layers.20.fc1.bias", + "decoder.layers.20.fc2.weight", + "decoder.layers.20.fc2.bias", + "decoder.layers.20.final_layer_norm.weight", + "decoder.layers.20.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.21.flat_param_0": { + "names": [ + "decoder.layers.21.self_attn.qkv_proj.weight", + "decoder.layers.21.self_attn.qkv_proj.bias", + "decoder.layers.21.self_attn.out_proj.weight", + "decoder.layers.21.self_attn.out_proj.bias", + "decoder.layers.21.self_attn_layer_norm.weight", + "decoder.layers.21.self_attn_layer_norm.bias", + "decoder.layers.21.fc1.weight", + "decoder.layers.21.fc1.bias", + "decoder.layers.21.fc2.weight", + "decoder.layers.21.fc2.bias", + "decoder.layers.21.final_layer_norm.weight", + "decoder.layers.21.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.22.flat_param_0": { + "names": [ + "decoder.layers.22.self_attn.qkv_proj.weight", + "decoder.layers.22.self_attn.qkv_proj.bias", + "decoder.layers.22.self_attn.out_proj.weight", + "decoder.layers.22.self_attn.out_proj.bias", + "decoder.layers.22.self_attn_layer_norm.weight", + "decoder.layers.22.self_attn_layer_norm.bias", + "decoder.layers.22.fc1.weight", + "decoder.layers.22.fc1.bias", + "decoder.layers.22.fc2.weight", + "decoder.layers.22.fc2.bias", + "decoder.layers.22.final_layer_norm.weight", + "decoder.layers.22.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.23.flat_param_0": { + "names": [ + "decoder.layers.23.self_attn.qkv_proj.weight", + "decoder.layers.23.self_attn.qkv_proj.bias", + "decoder.layers.23.self_attn.out_proj.weight", + "decoder.layers.23.self_attn.out_proj.bias", + "decoder.layers.23.self_attn_layer_norm.weight", + "decoder.layers.23.self_attn_layer_norm.bias", + "decoder.layers.23.fc1.weight", + "decoder.layers.23.fc1.bias", + "decoder.layers.23.fc2.weight", + "decoder.layers.23.fc2.bias", + "decoder.layers.23.final_layer_norm.weight", + "decoder.layers.23.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.24.flat_param_0": { + "names": [ + "decoder.layers.24.self_attn.qkv_proj.weight", + "decoder.layers.24.self_attn.qkv_proj.bias", + "decoder.layers.24.self_attn.out_proj.weight", + "decoder.layers.24.self_attn.out_proj.bias", + "decoder.layers.24.self_attn_layer_norm.weight", + "decoder.layers.24.self_attn_layer_norm.bias", + "decoder.layers.24.fc1.weight", + "decoder.layers.24.fc1.bias", + "decoder.layers.24.fc2.weight", + "decoder.layers.24.fc2.bias", + "decoder.layers.24.final_layer_norm.weight", + "decoder.layers.24.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.25.flat_param_0": { + "names": [ + "decoder.layers.25.self_attn.qkv_proj.weight", + "decoder.layers.25.self_attn.qkv_proj.bias", + "decoder.layers.25.self_attn.out_proj.weight", + "decoder.layers.25.self_attn.out_proj.bias", + "decoder.layers.25.self_attn_layer_norm.weight", + "decoder.layers.25.self_attn_layer_norm.bias", + "decoder.layers.25.fc1.weight", + "decoder.layers.25.fc1.bias", + "decoder.layers.25.fc2.weight", + "decoder.layers.25.fc2.bias", + "decoder.layers.25.final_layer_norm.weight", + "decoder.layers.25.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.26.flat_param_0": { + "names": [ + "decoder.layers.26.self_attn.qkv_proj.weight", + "decoder.layers.26.self_attn.qkv_proj.bias", + "decoder.layers.26.self_attn.out_proj.weight", + "decoder.layers.26.self_attn.out_proj.bias", + "decoder.layers.26.self_attn_layer_norm.weight", + "decoder.layers.26.self_attn_layer_norm.bias", + "decoder.layers.26.fc1.weight", + "decoder.layers.26.fc1.bias", + "decoder.layers.26.fc2.weight", + "decoder.layers.26.fc2.bias", + "decoder.layers.26.final_layer_norm.weight", + "decoder.layers.26.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.27.flat_param_0": { + "names": [ + "decoder.layers.27.self_attn.qkv_proj.weight", + "decoder.layers.27.self_attn.qkv_proj.bias", + "decoder.layers.27.self_attn.out_proj.weight", + "decoder.layers.27.self_attn.out_proj.bias", + "decoder.layers.27.self_attn_layer_norm.weight", + "decoder.layers.27.self_attn_layer_norm.bias", + "decoder.layers.27.fc1.weight", + "decoder.layers.27.fc1.bias", + "decoder.layers.27.fc2.weight", + "decoder.layers.27.fc2.bias", + "decoder.layers.27.final_layer_norm.weight", + "decoder.layers.27.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.28.flat_param_0": { + "names": [ + "decoder.layers.28.self_attn.qkv_proj.weight", + "decoder.layers.28.self_attn.qkv_proj.bias", + "decoder.layers.28.self_attn.out_proj.weight", + "decoder.layers.28.self_attn.out_proj.bias", + "decoder.layers.28.self_attn_layer_norm.weight", + "decoder.layers.28.self_attn_layer_norm.bias", + "decoder.layers.28.fc1.weight", + "decoder.layers.28.fc1.bias", + "decoder.layers.28.fc2.weight", + "decoder.layers.28.fc2.bias", + "decoder.layers.28.final_layer_norm.weight", + "decoder.layers.28.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.29.flat_param_0": { + "names": [ + "decoder.layers.29.self_attn.qkv_proj.weight", + "decoder.layers.29.self_attn.qkv_proj.bias", + "decoder.layers.29.self_attn.out_proj.weight", + "decoder.layers.29.self_attn.out_proj.bias", + "decoder.layers.29.self_attn_layer_norm.weight", + "decoder.layers.29.self_attn_layer_norm.bias", + "decoder.layers.29.fc1.weight", + "decoder.layers.29.fc1.bias", + "decoder.layers.29.fc2.weight", + "decoder.layers.29.fc2.bias", + "decoder.layers.29.final_layer_norm.weight", + "decoder.layers.29.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.30.flat_param_0": { + "names": [ + "decoder.layers.30.self_attn.qkv_proj.weight", + "decoder.layers.30.self_attn.qkv_proj.bias", + "decoder.layers.30.self_attn.out_proj.weight", + "decoder.layers.30.self_attn.out_proj.bias", + "decoder.layers.30.self_attn_layer_norm.weight", + "decoder.layers.30.self_attn_layer_norm.bias", + "decoder.layers.30.fc1.weight", + "decoder.layers.30.fc1.bias", + "decoder.layers.30.fc2.weight", + "decoder.layers.30.fc2.bias", + "decoder.layers.30.final_layer_norm.weight", + "decoder.layers.30.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.31.flat_param_0": { + "names": [ + "decoder.layers.31.self_attn.qkv_proj.weight", + "decoder.layers.31.self_attn.qkv_proj.bias", + "decoder.layers.31.self_attn.out_proj.weight", + "decoder.layers.31.self_attn.out_proj.bias", + "decoder.layers.31.self_attn_layer_norm.weight", + "decoder.layers.31.self_attn_layer_norm.bias", + "decoder.layers.31.fc1.weight", + "decoder.layers.31.fc1.bias", + "decoder.layers.31.fc2.weight", + "decoder.layers.31.fc2.bias", + "decoder.layers.31.final_layer_norm.weight", + "decoder.layers.31.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.32.flat_param_0": { + "names": [ + "decoder.layers.32.self_attn.qkv_proj.weight", + "decoder.layers.32.self_attn.qkv_proj.bias", + "decoder.layers.32.self_attn.out_proj.weight", + "decoder.layers.32.self_attn.out_proj.bias", + "decoder.layers.32.self_attn_layer_norm.weight", + "decoder.layers.32.self_attn_layer_norm.bias", + "decoder.layers.32.fc1.weight", + "decoder.layers.32.fc1.bias", + "decoder.layers.32.fc2.weight", + "decoder.layers.32.fc2.bias", + "decoder.layers.32.final_layer_norm.weight", + "decoder.layers.32.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.33.flat_param_0": { + "names": [ + "decoder.layers.33.self_attn.qkv_proj.weight", + "decoder.layers.33.self_attn.qkv_proj.bias", + "decoder.layers.33.self_attn.out_proj.weight", + "decoder.layers.33.self_attn.out_proj.bias", + "decoder.layers.33.self_attn_layer_norm.weight", + "decoder.layers.33.self_attn_layer_norm.bias", + "decoder.layers.33.fc1.weight", + "decoder.layers.33.fc1.bias", + "decoder.layers.33.fc2.weight", + "decoder.layers.33.fc2.bias", + "decoder.layers.33.final_layer_norm.weight", + "decoder.layers.33.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.34.flat_param_0": { + "names": [ + "decoder.layers.34.self_attn.qkv_proj.weight", + "decoder.layers.34.self_attn.qkv_proj.bias", + "decoder.layers.34.self_attn.out_proj.weight", + "decoder.layers.34.self_attn.out_proj.bias", + "decoder.layers.34.self_attn_layer_norm.weight", + "decoder.layers.34.self_attn_layer_norm.bias", + "decoder.layers.34.fc1.weight", + "decoder.layers.34.fc1.bias", + "decoder.layers.34.fc2.weight", + "decoder.layers.34.fc2.bias", + "decoder.layers.34.final_layer_norm.weight", + "decoder.layers.34.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.35.flat_param_0": { + "names": [ + "decoder.layers.35.self_attn.qkv_proj.weight", + "decoder.layers.35.self_attn.qkv_proj.bias", + "decoder.layers.35.self_attn.out_proj.weight", + "decoder.layers.35.self_attn.out_proj.bias", + "decoder.layers.35.self_attn_layer_norm.weight", + "decoder.layers.35.self_attn_layer_norm.bias", + "decoder.layers.35.fc1.weight", + "decoder.layers.35.fc1.bias", + "decoder.layers.35.fc2.weight", + "decoder.layers.35.fc2.bias", + "decoder.layers.35.final_layer_norm.weight", + "decoder.layers.35.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.36.flat_param_0": { + "names": [ + "decoder.layers.36.self_attn.qkv_proj.weight", + "decoder.layers.36.self_attn.qkv_proj.bias", + "decoder.layers.36.self_attn.out_proj.weight", + "decoder.layers.36.self_attn.out_proj.bias", + "decoder.layers.36.self_attn_layer_norm.weight", + "decoder.layers.36.self_attn_layer_norm.bias", + "decoder.layers.36.fc1.weight", + "decoder.layers.36.fc1.bias", + "decoder.layers.36.fc2.weight", + "decoder.layers.36.fc2.bias", + "decoder.layers.36.final_layer_norm.weight", + "decoder.layers.36.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.37.flat_param_0": { + "names": [ + "decoder.layers.37.self_attn.qkv_proj.weight", + "decoder.layers.37.self_attn.qkv_proj.bias", + "decoder.layers.37.self_attn.out_proj.weight", + "decoder.layers.37.self_attn.out_proj.bias", + "decoder.layers.37.self_attn_layer_norm.weight", + "decoder.layers.37.self_attn_layer_norm.bias", + "decoder.layers.37.fc1.weight", + "decoder.layers.37.fc1.bias", + "decoder.layers.37.fc2.weight", + "decoder.layers.37.fc2.bias", + "decoder.layers.37.final_layer_norm.weight", + "decoder.layers.37.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.38.flat_param_0": { + "names": [ + "decoder.layers.38.self_attn.qkv_proj.weight", + "decoder.layers.38.self_attn.qkv_proj.bias", + "decoder.layers.38.self_attn.out_proj.weight", + "decoder.layers.38.self_attn.out_proj.bias", + "decoder.layers.38.self_attn_layer_norm.weight", + "decoder.layers.38.self_attn_layer_norm.bias", + "decoder.layers.38.fc1.weight", + "decoder.layers.38.fc1.bias", + "decoder.layers.38.fc2.weight", + "decoder.layers.38.fc2.bias", + "decoder.layers.38.final_layer_norm.weight", + "decoder.layers.38.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.39.flat_param_0": { + "names": [ + "decoder.layers.39.self_attn.qkv_proj.weight", + "decoder.layers.39.self_attn.qkv_proj.bias", + "decoder.layers.39.self_attn.out_proj.weight", + "decoder.layers.39.self_attn.out_proj.bias", + "decoder.layers.39.self_attn_layer_norm.weight", + "decoder.layers.39.self_attn_layer_norm.bias", + "decoder.layers.39.fc1.weight", + "decoder.layers.39.fc1.bias", + "decoder.layers.39.fc2.weight", + "decoder.layers.39.fc2.bias", + "decoder.layers.39.final_layer_norm.weight", + "decoder.layers.39.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.40.flat_param_0": { + "names": [ + "decoder.layers.40.self_attn.qkv_proj.weight", + "decoder.layers.40.self_attn.qkv_proj.bias", + "decoder.layers.40.self_attn.out_proj.weight", + "decoder.layers.40.self_attn.out_proj.bias", + "decoder.layers.40.self_attn_layer_norm.weight", + "decoder.layers.40.self_attn_layer_norm.bias", + "decoder.layers.40.fc1.weight", + "decoder.layers.40.fc1.bias", + "decoder.layers.40.fc2.weight", + "decoder.layers.40.fc2.bias", + "decoder.layers.40.final_layer_norm.weight", + "decoder.layers.40.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.41.flat_param_0": { + "names": [ + "decoder.layers.41.self_attn.qkv_proj.weight", + "decoder.layers.41.self_attn.qkv_proj.bias", + "decoder.layers.41.self_attn.out_proj.weight", + "decoder.layers.41.self_attn.out_proj.bias", + "decoder.layers.41.self_attn_layer_norm.weight", + "decoder.layers.41.self_attn_layer_norm.bias", + "decoder.layers.41.fc1.weight", + "decoder.layers.41.fc1.bias", + "decoder.layers.41.fc2.weight", + "decoder.layers.41.fc2.bias", + "decoder.layers.41.final_layer_norm.weight", + "decoder.layers.41.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.42.flat_param_0": { + "names": [ + "decoder.layers.42.self_attn.qkv_proj.weight", + "decoder.layers.42.self_attn.qkv_proj.bias", + "decoder.layers.42.self_attn.out_proj.weight", + "decoder.layers.42.self_attn.out_proj.bias", + "decoder.layers.42.self_attn_layer_norm.weight", + "decoder.layers.42.self_attn_layer_norm.bias", + "decoder.layers.42.fc1.weight", + "decoder.layers.42.fc1.bias", + "decoder.layers.42.fc2.weight", + "decoder.layers.42.fc2.bias", + "decoder.layers.42.final_layer_norm.weight", + "decoder.layers.42.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.43.flat_param_0": { + "names": [ + "decoder.layers.43.self_attn.qkv_proj.weight", + "decoder.layers.43.self_attn.qkv_proj.bias", + "decoder.layers.43.self_attn.out_proj.weight", + "decoder.layers.43.self_attn.out_proj.bias", + "decoder.layers.43.self_attn_layer_norm.weight", + "decoder.layers.43.self_attn_layer_norm.bias", + "decoder.layers.43.fc1.weight", + "decoder.layers.43.fc1.bias", + "decoder.layers.43.fc2.weight", + "decoder.layers.43.fc2.bias", + "decoder.layers.43.final_layer_norm.weight", + "decoder.layers.43.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.44.flat_param_0": { + "names": [ + "decoder.layers.44.self_attn.qkv_proj.weight", + "decoder.layers.44.self_attn.qkv_proj.bias", + "decoder.layers.44.self_attn.out_proj.weight", + "decoder.layers.44.self_attn.out_proj.bias", + "decoder.layers.44.self_attn_layer_norm.weight", + "decoder.layers.44.self_attn_layer_norm.bias", + "decoder.layers.44.fc1.weight", + "decoder.layers.44.fc1.bias", + "decoder.layers.44.fc2.weight", + "decoder.layers.44.fc2.bias", + "decoder.layers.44.final_layer_norm.weight", + "decoder.layers.44.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.45.flat_param_0": { + "names": [ + "decoder.layers.45.self_attn.qkv_proj.weight", + "decoder.layers.45.self_attn.qkv_proj.bias", + "decoder.layers.45.self_attn.out_proj.weight", + "decoder.layers.45.self_attn.out_proj.bias", + "decoder.layers.45.self_attn_layer_norm.weight", + "decoder.layers.45.self_attn_layer_norm.bias", + "decoder.layers.45.fc1.weight", + "decoder.layers.45.fc1.bias", + "decoder.layers.45.fc2.weight", + "decoder.layers.45.fc2.bias", + "decoder.layers.45.final_layer_norm.weight", + "decoder.layers.45.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.46.flat_param_0": { + "names": [ + "decoder.layers.46.self_attn.qkv_proj.weight", + "decoder.layers.46.self_attn.qkv_proj.bias", + "decoder.layers.46.self_attn.out_proj.weight", + "decoder.layers.46.self_attn.out_proj.bias", + "decoder.layers.46.self_attn_layer_norm.weight", + "decoder.layers.46.self_attn_layer_norm.bias", + "decoder.layers.46.fc1.weight", + "decoder.layers.46.fc1.bias", + "decoder.layers.46.fc2.weight", + "decoder.layers.46.fc2.bias", + "decoder.layers.46.final_layer_norm.weight", + "decoder.layers.46.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.47.flat_param_0": { + "names": [ + "decoder.layers.47.self_attn.qkv_proj.weight", + "decoder.layers.47.self_attn.qkv_proj.bias", + "decoder.layers.47.self_attn.out_proj.weight", + "decoder.layers.47.self_attn.out_proj.bias", + "decoder.layers.47.self_attn_layer_norm.weight", + "decoder.layers.47.self_attn_layer_norm.bias", + "decoder.layers.47.fc1.weight", + "decoder.layers.47.fc1.bias", + "decoder.layers.47.fc2.weight", + "decoder.layers.47.fc2.bias", + "decoder.layers.47.final_layer_norm.weight", + "decoder.layers.47.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.48.flat_param_0": { + "names": [ + "decoder.layers.48.self_attn.qkv_proj.weight", + "decoder.layers.48.self_attn.qkv_proj.bias", + "decoder.layers.48.self_attn.out_proj.weight", + "decoder.layers.48.self_attn.out_proj.bias", + "decoder.layers.48.self_attn_layer_norm.weight", + "decoder.layers.48.self_attn_layer_norm.bias", + "decoder.layers.48.fc1.weight", + "decoder.layers.48.fc1.bias", + "decoder.layers.48.fc2.weight", + "decoder.layers.48.fc2.bias", + "decoder.layers.48.final_layer_norm.weight", + "decoder.layers.48.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.49.flat_param_0": { + "names": [ + "decoder.layers.49.self_attn.qkv_proj.weight", + "decoder.layers.49.self_attn.qkv_proj.bias", + "decoder.layers.49.self_attn.out_proj.weight", + "decoder.layers.49.self_attn.out_proj.bias", + "decoder.layers.49.self_attn_layer_norm.weight", + "decoder.layers.49.self_attn_layer_norm.bias", + "decoder.layers.49.fc1.weight", + "decoder.layers.49.fc1.bias", + "decoder.layers.49.fc2.weight", + "decoder.layers.49.fc2.bias", + "decoder.layers.49.final_layer_norm.weight", + "decoder.layers.49.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.50.flat_param_0": { + "names": [ + "decoder.layers.50.self_attn.qkv_proj.weight", + "decoder.layers.50.self_attn.qkv_proj.bias", + "decoder.layers.50.self_attn.out_proj.weight", + "decoder.layers.50.self_attn.out_proj.bias", + "decoder.layers.50.self_attn_layer_norm.weight", + "decoder.layers.50.self_attn_layer_norm.bias", + "decoder.layers.50.fc1.weight", + "decoder.layers.50.fc1.bias", + "decoder.layers.50.fc2.weight", + "decoder.layers.50.fc2.bias", + "decoder.layers.50.final_layer_norm.weight", + "decoder.layers.50.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.51.flat_param_0": { + "names": [ + "decoder.layers.51.self_attn.qkv_proj.weight", + "decoder.layers.51.self_attn.qkv_proj.bias", + "decoder.layers.51.self_attn.out_proj.weight", + "decoder.layers.51.self_attn.out_proj.bias", + "decoder.layers.51.self_attn_layer_norm.weight", + "decoder.layers.51.self_attn_layer_norm.bias", + "decoder.layers.51.fc1.weight", + "decoder.layers.51.fc1.bias", + "decoder.layers.51.fc2.weight", + "decoder.layers.51.fc2.bias", + "decoder.layers.51.final_layer_norm.weight", + "decoder.layers.51.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.52.flat_param_0": { + "names": [ + "decoder.layers.52.self_attn.qkv_proj.weight", + "decoder.layers.52.self_attn.qkv_proj.bias", + "decoder.layers.52.self_attn.out_proj.weight", + "decoder.layers.52.self_attn.out_proj.bias", + "decoder.layers.52.self_attn_layer_norm.weight", + "decoder.layers.52.self_attn_layer_norm.bias", + "decoder.layers.52.fc1.weight", + "decoder.layers.52.fc1.bias", + "decoder.layers.52.fc2.weight", + "decoder.layers.52.fc2.bias", + "decoder.layers.52.final_layer_norm.weight", + "decoder.layers.52.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.53.flat_param_0": { + "names": [ + "decoder.layers.53.self_attn.qkv_proj.weight", + "decoder.layers.53.self_attn.qkv_proj.bias", + "decoder.layers.53.self_attn.out_proj.weight", + "decoder.layers.53.self_attn.out_proj.bias", + "decoder.layers.53.self_attn_layer_norm.weight", + "decoder.layers.53.self_attn_layer_norm.bias", + "decoder.layers.53.fc1.weight", + "decoder.layers.53.fc1.bias", + "decoder.layers.53.fc2.weight", + "decoder.layers.53.fc2.bias", + "decoder.layers.53.final_layer_norm.weight", + "decoder.layers.53.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.54.flat_param_0": { + "names": [ + "decoder.layers.54.self_attn.qkv_proj.weight", + "decoder.layers.54.self_attn.qkv_proj.bias", + "decoder.layers.54.self_attn.out_proj.weight", + "decoder.layers.54.self_attn.out_proj.bias", + "decoder.layers.54.self_attn_layer_norm.weight", + "decoder.layers.54.self_attn_layer_norm.bias", + "decoder.layers.54.fc1.weight", + "decoder.layers.54.fc1.bias", + "decoder.layers.54.fc2.weight", + "decoder.layers.54.fc2.bias", + "decoder.layers.54.final_layer_norm.weight", + "decoder.layers.54.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.55.flat_param_0": { + "names": [ + "decoder.layers.55.self_attn.qkv_proj.weight", + "decoder.layers.55.self_attn.qkv_proj.bias", + "decoder.layers.55.self_attn.out_proj.weight", + "decoder.layers.55.self_attn.out_proj.bias", + "decoder.layers.55.self_attn_layer_norm.weight", + "decoder.layers.55.self_attn_layer_norm.bias", + "decoder.layers.55.fc1.weight", + "decoder.layers.55.fc1.bias", + "decoder.layers.55.fc2.weight", + "decoder.layers.55.fc2.bias", + "decoder.layers.55.final_layer_norm.weight", + "decoder.layers.55.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.56.flat_param_0": { + "names": [ + "decoder.layers.56.self_attn.qkv_proj.weight", + "decoder.layers.56.self_attn.qkv_proj.bias", + "decoder.layers.56.self_attn.out_proj.weight", + "decoder.layers.56.self_attn.out_proj.bias", + "decoder.layers.56.self_attn_layer_norm.weight", + "decoder.layers.56.self_attn_layer_norm.bias", + "decoder.layers.56.fc1.weight", + "decoder.layers.56.fc1.bias", + "decoder.layers.56.fc2.weight", + "decoder.layers.56.fc2.bias", + "decoder.layers.56.final_layer_norm.weight", + "decoder.layers.56.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.57.flat_param_0": { + "names": [ + "decoder.layers.57.self_attn.qkv_proj.weight", + "decoder.layers.57.self_attn.qkv_proj.bias", + "decoder.layers.57.self_attn.out_proj.weight", + "decoder.layers.57.self_attn.out_proj.bias", + "decoder.layers.57.self_attn_layer_norm.weight", + "decoder.layers.57.self_attn_layer_norm.bias", + "decoder.layers.57.fc1.weight", + "decoder.layers.57.fc1.bias", + "decoder.layers.57.fc2.weight", + "decoder.layers.57.fc2.bias", + "decoder.layers.57.final_layer_norm.weight", + "decoder.layers.57.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.58.flat_param_0": { + "names": [ + "decoder.layers.58.self_attn.qkv_proj.weight", + "decoder.layers.58.self_attn.qkv_proj.bias", + "decoder.layers.58.self_attn.out_proj.weight", + "decoder.layers.58.self_attn.out_proj.bias", + "decoder.layers.58.self_attn_layer_norm.weight", + "decoder.layers.58.self_attn_layer_norm.bias", + "decoder.layers.58.fc1.weight", + "decoder.layers.58.fc1.bias", + "decoder.layers.58.fc2.weight", + "decoder.layers.58.fc2.bias", + "decoder.layers.58.final_layer_norm.weight", + "decoder.layers.58.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.59.flat_param_0": { + "names": [ + "decoder.layers.59.self_attn.qkv_proj.weight", + "decoder.layers.59.self_attn.qkv_proj.bias", + "decoder.layers.59.self_attn.out_proj.weight", + "decoder.layers.59.self_attn.out_proj.bias", + "decoder.layers.59.self_attn_layer_norm.weight", + "decoder.layers.59.self_attn_layer_norm.bias", + "decoder.layers.59.fc1.weight", + "decoder.layers.59.fc1.bias", + "decoder.layers.59.fc2.weight", + "decoder.layers.59.fc2.bias", + "decoder.layers.59.final_layer_norm.weight", + "decoder.layers.59.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.60.flat_param_0": { + "names": [ + "decoder.layers.60.self_attn.qkv_proj.weight", + "decoder.layers.60.self_attn.qkv_proj.bias", + "decoder.layers.60.self_attn.out_proj.weight", + "decoder.layers.60.self_attn.out_proj.bias", + "decoder.layers.60.self_attn_layer_norm.weight", + "decoder.layers.60.self_attn_layer_norm.bias", + "decoder.layers.60.fc1.weight", + "decoder.layers.60.fc1.bias", + "decoder.layers.60.fc2.weight", + "decoder.layers.60.fc2.bias", + "decoder.layers.60.final_layer_norm.weight", + "decoder.layers.60.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.61.flat_param_0": { + "names": [ + "decoder.layers.61.self_attn.qkv_proj.weight", + "decoder.layers.61.self_attn.qkv_proj.bias", + "decoder.layers.61.self_attn.out_proj.weight", + "decoder.layers.61.self_attn.out_proj.bias", + "decoder.layers.61.self_attn_layer_norm.weight", + "decoder.layers.61.self_attn_layer_norm.bias", + "decoder.layers.61.fc1.weight", + "decoder.layers.61.fc1.bias", + "decoder.layers.61.fc2.weight", + "decoder.layers.61.fc2.bias", + "decoder.layers.61.final_layer_norm.weight", + "decoder.layers.61.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.62.flat_param_0": { + "names": [ + "decoder.layers.62.self_attn.qkv_proj.weight", + "decoder.layers.62.self_attn.qkv_proj.bias", + "decoder.layers.62.self_attn.out_proj.weight", + "decoder.layers.62.self_attn.out_proj.bias", + "decoder.layers.62.self_attn_layer_norm.weight", + "decoder.layers.62.self_attn_layer_norm.bias", + "decoder.layers.62.fc1.weight", + "decoder.layers.62.fc1.bias", + "decoder.layers.62.fc2.weight", + "decoder.layers.62.fc2.bias", + "decoder.layers.62.final_layer_norm.weight", + "decoder.layers.62.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.63.flat_param_0": { + "names": [ + "decoder.layers.63.self_attn.qkv_proj.weight", + "decoder.layers.63.self_attn.qkv_proj.bias", + "decoder.layers.63.self_attn.out_proj.weight", + "decoder.layers.63.self_attn.out_proj.bias", + "decoder.layers.63.self_attn_layer_norm.weight", + "decoder.layers.63.self_attn_layer_norm.bias", + "decoder.layers.63.fc1.weight", + "decoder.layers.63.fc1.bias", + "decoder.layers.63.fc2.weight", + "decoder.layers.63.fc2.bias", + "decoder.layers.63.final_layer_norm.weight", + "decoder.layers.63.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.64.flat_param_0": { + "names": [ + "decoder.layers.64.self_attn.qkv_proj.weight", + "decoder.layers.64.self_attn.qkv_proj.bias", + "decoder.layers.64.self_attn.out_proj.weight", + "decoder.layers.64.self_attn.out_proj.bias", + "decoder.layers.64.self_attn_layer_norm.weight", + "decoder.layers.64.self_attn_layer_norm.bias", + "decoder.layers.64.fc1.weight", + "decoder.layers.64.fc1.bias", + "decoder.layers.64.fc2.weight", + "decoder.layers.64.fc2.bias", + "decoder.layers.64.final_layer_norm.weight", + "decoder.layers.64.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.65.flat_param_0": { + "names": [ + "decoder.layers.65.self_attn.qkv_proj.weight", + "decoder.layers.65.self_attn.qkv_proj.bias", + "decoder.layers.65.self_attn.out_proj.weight", + "decoder.layers.65.self_attn.out_proj.bias", + "decoder.layers.65.self_attn_layer_norm.weight", + "decoder.layers.65.self_attn_layer_norm.bias", + "decoder.layers.65.fc1.weight", + "decoder.layers.65.fc1.bias", + "decoder.layers.65.fc2.weight", + "decoder.layers.65.fc2.bias", + "decoder.layers.65.final_layer_norm.weight", + "decoder.layers.65.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.66.flat_param_0": { + "names": [ + "decoder.layers.66.self_attn.qkv_proj.weight", + "decoder.layers.66.self_attn.qkv_proj.bias", + "decoder.layers.66.self_attn.out_proj.weight", + "decoder.layers.66.self_attn.out_proj.bias", + "decoder.layers.66.self_attn_layer_norm.weight", + "decoder.layers.66.self_attn_layer_norm.bias", + "decoder.layers.66.fc1.weight", + "decoder.layers.66.fc1.bias", + "decoder.layers.66.fc2.weight", + "decoder.layers.66.fc2.bias", + "decoder.layers.66.final_layer_norm.weight", + "decoder.layers.66.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.67.flat_param_0": { + "names": [ + "decoder.layers.67.self_attn.qkv_proj.weight", + "decoder.layers.67.self_attn.qkv_proj.bias", + "decoder.layers.67.self_attn.out_proj.weight", + "decoder.layers.67.self_attn.out_proj.bias", + "decoder.layers.67.self_attn_layer_norm.weight", + "decoder.layers.67.self_attn_layer_norm.bias", + "decoder.layers.67.fc1.weight", + "decoder.layers.67.fc1.bias", + "decoder.layers.67.fc2.weight", + "decoder.layers.67.fc2.bias", + "decoder.layers.67.final_layer_norm.weight", + "decoder.layers.67.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.68.flat_param_0": { + "names": [ + "decoder.layers.68.self_attn.qkv_proj.weight", + "decoder.layers.68.self_attn.qkv_proj.bias", + "decoder.layers.68.self_attn.out_proj.weight", + "decoder.layers.68.self_attn.out_proj.bias", + "decoder.layers.68.self_attn_layer_norm.weight", + "decoder.layers.68.self_attn_layer_norm.bias", + "decoder.layers.68.fc1.weight", + "decoder.layers.68.fc1.bias", + "decoder.layers.68.fc2.weight", + "decoder.layers.68.fc2.bias", + "decoder.layers.68.final_layer_norm.weight", + "decoder.layers.68.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.69.flat_param_0": { + "names": [ + "decoder.layers.69.self_attn.qkv_proj.weight", + "decoder.layers.69.self_attn.qkv_proj.bias", + "decoder.layers.69.self_attn.out_proj.weight", + "decoder.layers.69.self_attn.out_proj.bias", + "decoder.layers.69.self_attn_layer_norm.weight", + "decoder.layers.69.self_attn_layer_norm.bias", + "decoder.layers.69.fc1.weight", + "decoder.layers.69.fc1.bias", + "decoder.layers.69.fc2.weight", + "decoder.layers.69.fc2.bias", + "decoder.layers.69.final_layer_norm.weight", + "decoder.layers.69.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.70.flat_param_0": { + "names": [ + "decoder.layers.70.self_attn.qkv_proj.weight", + "decoder.layers.70.self_attn.qkv_proj.bias", + "decoder.layers.70.self_attn.out_proj.weight", + "decoder.layers.70.self_attn.out_proj.bias", + "decoder.layers.70.self_attn_layer_norm.weight", + "decoder.layers.70.self_attn_layer_norm.bias", + "decoder.layers.70.fc1.weight", + "decoder.layers.70.fc1.bias", + "decoder.layers.70.fc2.weight", + "decoder.layers.70.fc2.bias", + "decoder.layers.70.final_layer_norm.weight", + "decoder.layers.70.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.71.flat_param_0": { + "names": [ + "decoder.layers.71.self_attn.qkv_proj.weight", + "decoder.layers.71.self_attn.qkv_proj.bias", + "decoder.layers.71.self_attn.out_proj.weight", + "decoder.layers.71.self_attn.out_proj.bias", + "decoder.layers.71.self_attn_layer_norm.weight", + "decoder.layers.71.self_attn_layer_norm.bias", + "decoder.layers.71.fc1.weight", + "decoder.layers.71.fc1.bias", + "decoder.layers.71.fc2.weight", + "decoder.layers.71.fc2.bias", + "decoder.layers.71.final_layer_norm.weight", + "decoder.layers.71.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.72.flat_param_0": { + "names": [ + "decoder.layers.72.self_attn.qkv_proj.weight", + "decoder.layers.72.self_attn.qkv_proj.bias", + "decoder.layers.72.self_attn.out_proj.weight", + "decoder.layers.72.self_attn.out_proj.bias", + "decoder.layers.72.self_attn_layer_norm.weight", + "decoder.layers.72.self_attn_layer_norm.bias", + "decoder.layers.72.fc1.weight", + "decoder.layers.72.fc1.bias", + "decoder.layers.72.fc2.weight", + "decoder.layers.72.fc2.bias", + "decoder.layers.72.final_layer_norm.weight", + "decoder.layers.72.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.73.flat_param_0": { + "names": [ + "decoder.layers.73.self_attn.qkv_proj.weight", + "decoder.layers.73.self_attn.qkv_proj.bias", + "decoder.layers.73.self_attn.out_proj.weight", + "decoder.layers.73.self_attn.out_proj.bias", + "decoder.layers.73.self_attn_layer_norm.weight", + "decoder.layers.73.self_attn_layer_norm.bias", + "decoder.layers.73.fc1.weight", + "decoder.layers.73.fc1.bias", + "decoder.layers.73.fc2.weight", + "decoder.layers.73.fc2.bias", + "decoder.layers.73.final_layer_norm.weight", + "decoder.layers.73.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.74.flat_param_0": { + "names": [ + "decoder.layers.74.self_attn.qkv_proj.weight", + "decoder.layers.74.self_attn.qkv_proj.bias", + "decoder.layers.74.self_attn.out_proj.weight", + "decoder.layers.74.self_attn.out_proj.bias", + "decoder.layers.74.self_attn_layer_norm.weight", + "decoder.layers.74.self_attn_layer_norm.bias", + "decoder.layers.74.fc1.weight", + "decoder.layers.74.fc1.bias", + "decoder.layers.74.fc2.weight", + "decoder.layers.74.fc2.bias", + "decoder.layers.74.final_layer_norm.weight", + "decoder.layers.74.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.75.flat_param_0": { + "names": [ + "decoder.layers.75.self_attn.qkv_proj.weight", + "decoder.layers.75.self_attn.qkv_proj.bias", + "decoder.layers.75.self_attn.out_proj.weight", + "decoder.layers.75.self_attn.out_proj.bias", + "decoder.layers.75.self_attn_layer_norm.weight", + "decoder.layers.75.self_attn_layer_norm.bias", + "decoder.layers.75.fc1.weight", + "decoder.layers.75.fc1.bias", + "decoder.layers.75.fc2.weight", + "decoder.layers.75.fc2.bias", + "decoder.layers.75.final_layer_norm.weight", + "decoder.layers.75.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.76.flat_param_0": { + "names": [ + "decoder.layers.76.self_attn.qkv_proj.weight", + "decoder.layers.76.self_attn.qkv_proj.bias", + "decoder.layers.76.self_attn.out_proj.weight", + "decoder.layers.76.self_attn.out_proj.bias", + "decoder.layers.76.self_attn_layer_norm.weight", + "decoder.layers.76.self_attn_layer_norm.bias", + "decoder.layers.76.fc1.weight", + "decoder.layers.76.fc1.bias", + "decoder.layers.76.fc2.weight", + "decoder.layers.76.fc2.bias", + "decoder.layers.76.final_layer_norm.weight", + "decoder.layers.76.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.77.flat_param_0": { + "names": [ + "decoder.layers.77.self_attn.qkv_proj.weight", + "decoder.layers.77.self_attn.qkv_proj.bias", + "decoder.layers.77.self_attn.out_proj.weight", + "decoder.layers.77.self_attn.out_proj.bias", + "decoder.layers.77.self_attn_layer_norm.weight", + "decoder.layers.77.self_attn_layer_norm.bias", + "decoder.layers.77.fc1.weight", + "decoder.layers.77.fc1.bias", + "decoder.layers.77.fc2.weight", + "decoder.layers.77.fc2.bias", + "decoder.layers.77.final_layer_norm.weight", + "decoder.layers.77.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.78.flat_param_0": { + "names": [ + "decoder.layers.78.self_attn.qkv_proj.weight", + "decoder.layers.78.self_attn.qkv_proj.bias", + "decoder.layers.78.self_attn.out_proj.weight", + "decoder.layers.78.self_attn.out_proj.bias", + "decoder.layers.78.self_attn_layer_norm.weight", + "decoder.layers.78.self_attn_layer_norm.bias", + "decoder.layers.78.fc1.weight", + "decoder.layers.78.fc1.bias", + "decoder.layers.78.fc2.weight", + "decoder.layers.78.fc2.bias", + "decoder.layers.78.final_layer_norm.weight", + "decoder.layers.78.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.79.flat_param_0": { + "names": [ + "decoder.layers.79.self_attn.qkv_proj.weight", + "decoder.layers.79.self_attn.qkv_proj.bias", + "decoder.layers.79.self_attn.out_proj.weight", + "decoder.layers.79.self_attn.out_proj.bias", + "decoder.layers.79.self_attn_layer_norm.weight", + "decoder.layers.79.self_attn_layer_norm.bias", + "decoder.layers.79.fc1.weight", + "decoder.layers.79.fc1.bias", + "decoder.layers.79.fc2.weight", + "decoder.layers.79.fc2.bias", + "decoder.layers.79.final_layer_norm.weight", + "decoder.layers.79.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.80.flat_param_0": { + "names": [ + "decoder.layers.80.self_attn.qkv_proj.weight", + "decoder.layers.80.self_attn.qkv_proj.bias", + "decoder.layers.80.self_attn.out_proj.weight", + "decoder.layers.80.self_attn.out_proj.bias", + "decoder.layers.80.self_attn_layer_norm.weight", + "decoder.layers.80.self_attn_layer_norm.bias", + "decoder.layers.80.fc1.weight", + "decoder.layers.80.fc1.bias", + "decoder.layers.80.fc2.weight", + "decoder.layers.80.fc2.bias", + "decoder.layers.80.final_layer_norm.weight", + "decoder.layers.80.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.81.flat_param_0": { + "names": [ + "decoder.layers.81.self_attn.qkv_proj.weight", + "decoder.layers.81.self_attn.qkv_proj.bias", + "decoder.layers.81.self_attn.out_proj.weight", + "decoder.layers.81.self_attn.out_proj.bias", + "decoder.layers.81.self_attn_layer_norm.weight", + "decoder.layers.81.self_attn_layer_norm.bias", + "decoder.layers.81.fc1.weight", + "decoder.layers.81.fc1.bias", + "decoder.layers.81.fc2.weight", + "decoder.layers.81.fc2.bias", + "decoder.layers.81.final_layer_norm.weight", + "decoder.layers.81.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.82.flat_param_0": { + "names": [ + "decoder.layers.82.self_attn.qkv_proj.weight", + "decoder.layers.82.self_attn.qkv_proj.bias", + "decoder.layers.82.self_attn.out_proj.weight", + "decoder.layers.82.self_attn.out_proj.bias", + "decoder.layers.82.self_attn_layer_norm.weight", + "decoder.layers.82.self_attn_layer_norm.bias", + "decoder.layers.82.fc1.weight", + "decoder.layers.82.fc1.bias", + "decoder.layers.82.fc2.weight", + "decoder.layers.82.fc2.bias", + "decoder.layers.82.final_layer_norm.weight", + "decoder.layers.82.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.83.flat_param_0": { + "names": [ + "decoder.layers.83.self_attn.qkv_proj.weight", + "decoder.layers.83.self_attn.qkv_proj.bias", + "decoder.layers.83.self_attn.out_proj.weight", + "decoder.layers.83.self_attn.out_proj.bias", + "decoder.layers.83.self_attn_layer_norm.weight", + "decoder.layers.83.self_attn_layer_norm.bias", + "decoder.layers.83.fc1.weight", + "decoder.layers.83.fc1.bias", + "decoder.layers.83.fc2.weight", + "decoder.layers.83.fc2.bias", + "decoder.layers.83.final_layer_norm.weight", + "decoder.layers.83.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.84.flat_param_0": { + "names": [ + "decoder.layers.84.self_attn.qkv_proj.weight", + "decoder.layers.84.self_attn.qkv_proj.bias", + "decoder.layers.84.self_attn.out_proj.weight", + "decoder.layers.84.self_attn.out_proj.bias", + "decoder.layers.84.self_attn_layer_norm.weight", + "decoder.layers.84.self_attn_layer_norm.bias", + "decoder.layers.84.fc1.weight", + "decoder.layers.84.fc1.bias", + "decoder.layers.84.fc2.weight", + "decoder.layers.84.fc2.bias", + "decoder.layers.84.final_layer_norm.weight", + "decoder.layers.84.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.85.flat_param_0": { + "names": [ + "decoder.layers.85.self_attn.qkv_proj.weight", + "decoder.layers.85.self_attn.qkv_proj.bias", + "decoder.layers.85.self_attn.out_proj.weight", + "decoder.layers.85.self_attn.out_proj.bias", + "decoder.layers.85.self_attn_layer_norm.weight", + "decoder.layers.85.self_attn_layer_norm.bias", + "decoder.layers.85.fc1.weight", + "decoder.layers.85.fc1.bias", + "decoder.layers.85.fc2.weight", + "decoder.layers.85.fc2.bias", + "decoder.layers.85.final_layer_norm.weight", + "decoder.layers.85.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.86.flat_param_0": { + "names": [ + "decoder.layers.86.self_attn.qkv_proj.weight", + "decoder.layers.86.self_attn.qkv_proj.bias", + "decoder.layers.86.self_attn.out_proj.weight", + "decoder.layers.86.self_attn.out_proj.bias", + "decoder.layers.86.self_attn_layer_norm.weight", + "decoder.layers.86.self_attn_layer_norm.bias", + "decoder.layers.86.fc1.weight", + "decoder.layers.86.fc1.bias", + "decoder.layers.86.fc2.weight", + "decoder.layers.86.fc2.bias", + "decoder.layers.86.final_layer_norm.weight", + "decoder.layers.86.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.87.flat_param_0": { + "names": [ + "decoder.layers.87.self_attn.qkv_proj.weight", + "decoder.layers.87.self_attn.qkv_proj.bias", + "decoder.layers.87.self_attn.out_proj.weight", + "decoder.layers.87.self_attn.out_proj.bias", + "decoder.layers.87.self_attn_layer_norm.weight", + "decoder.layers.87.self_attn_layer_norm.bias", + "decoder.layers.87.fc1.weight", + "decoder.layers.87.fc1.bias", + "decoder.layers.87.fc2.weight", + "decoder.layers.87.fc2.bias", + "decoder.layers.87.final_layer_norm.weight", + "decoder.layers.87.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.88.flat_param_0": { + "names": [ + "decoder.layers.88.self_attn.qkv_proj.weight", + "decoder.layers.88.self_attn.qkv_proj.bias", + "decoder.layers.88.self_attn.out_proj.weight", + "decoder.layers.88.self_attn.out_proj.bias", + "decoder.layers.88.self_attn_layer_norm.weight", + "decoder.layers.88.self_attn_layer_norm.bias", + "decoder.layers.88.fc1.weight", + "decoder.layers.88.fc1.bias", + "decoder.layers.88.fc2.weight", + "decoder.layers.88.fc2.bias", + "decoder.layers.88.final_layer_norm.weight", + "decoder.layers.88.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.89.flat_param_0": { + "names": [ + "decoder.layers.89.self_attn.qkv_proj.weight", + "decoder.layers.89.self_attn.qkv_proj.bias", + "decoder.layers.89.self_attn.out_proj.weight", + "decoder.layers.89.self_attn.out_proj.bias", + "decoder.layers.89.self_attn_layer_norm.weight", + "decoder.layers.89.self_attn_layer_norm.bias", + "decoder.layers.89.fc1.weight", + "decoder.layers.89.fc1.bias", + "decoder.layers.89.fc2.weight", + "decoder.layers.89.fc2.bias", + "decoder.layers.89.final_layer_norm.weight", + "decoder.layers.89.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.90.flat_param_0": { + "names": [ + "decoder.layers.90.self_attn.qkv_proj.weight", + "decoder.layers.90.self_attn.qkv_proj.bias", + "decoder.layers.90.self_attn.out_proj.weight", + "decoder.layers.90.self_attn.out_proj.bias", + "decoder.layers.90.self_attn_layer_norm.weight", + "decoder.layers.90.self_attn_layer_norm.bias", + "decoder.layers.90.fc1.weight", + "decoder.layers.90.fc1.bias", + "decoder.layers.90.fc2.weight", + "decoder.layers.90.fc2.bias", + "decoder.layers.90.final_layer_norm.weight", + "decoder.layers.90.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.91.flat_param_0": { + "names": [ + "decoder.layers.91.self_attn.qkv_proj.weight", + "decoder.layers.91.self_attn.qkv_proj.bias", + "decoder.layers.91.self_attn.out_proj.weight", + "decoder.layers.91.self_attn.out_proj.bias", + "decoder.layers.91.self_attn_layer_norm.weight", + "decoder.layers.91.self_attn_layer_norm.bias", + "decoder.layers.91.fc1.weight", + "decoder.layers.91.fc1.bias", + "decoder.layers.91.fc2.weight", + "decoder.layers.91.fc2.bias", + "decoder.layers.91.final_layer_norm.weight", + "decoder.layers.91.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.92.flat_param_0": { + "names": [ + "decoder.layers.92.self_attn.qkv_proj.weight", + "decoder.layers.92.self_attn.qkv_proj.bias", + "decoder.layers.92.self_attn.out_proj.weight", + "decoder.layers.92.self_attn.out_proj.bias", + "decoder.layers.92.self_attn_layer_norm.weight", + "decoder.layers.92.self_attn_layer_norm.bias", + "decoder.layers.92.fc1.weight", + "decoder.layers.92.fc1.bias", + "decoder.layers.92.fc2.weight", + "decoder.layers.92.fc2.bias", + "decoder.layers.92.final_layer_norm.weight", + "decoder.layers.92.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.93.flat_param_0": { + "names": [ + "decoder.layers.93.self_attn.qkv_proj.weight", + "decoder.layers.93.self_attn.qkv_proj.bias", + "decoder.layers.93.self_attn.out_proj.weight", + "decoder.layers.93.self_attn.out_proj.bias", + "decoder.layers.93.self_attn_layer_norm.weight", + "decoder.layers.93.self_attn_layer_norm.bias", + "decoder.layers.93.fc1.weight", + "decoder.layers.93.fc1.bias", + "decoder.layers.93.fc2.weight", + "decoder.layers.93.fc2.bias", + "decoder.layers.93.final_layer_norm.weight", + "decoder.layers.93.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.94.flat_param_0": { + "names": [ + "decoder.layers.94.self_attn.qkv_proj.weight", + "decoder.layers.94.self_attn.qkv_proj.bias", + "decoder.layers.94.self_attn.out_proj.weight", + "decoder.layers.94.self_attn.out_proj.bias", + "decoder.layers.94.self_attn_layer_norm.weight", + "decoder.layers.94.self_attn_layer_norm.bias", + "decoder.layers.94.fc1.weight", + "decoder.layers.94.fc1.bias", + "decoder.layers.94.fc2.weight", + "decoder.layers.94.fc2.bias", + "decoder.layers.94.final_layer_norm.weight", + "decoder.layers.94.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.95.flat_param_0": { + "names": [ + "decoder.layers.95.self_attn.qkv_proj.weight", + "decoder.layers.95.self_attn.qkv_proj.bias", + "decoder.layers.95.self_attn.out_proj.weight", + "decoder.layers.95.self_attn.out_proj.bias", + "decoder.layers.95.self_attn_layer_norm.weight", + "decoder.layers.95.self_attn_layer_norm.bias", + "decoder.layers.95.fc1.weight", + "decoder.layers.95.fc1.bias", + "decoder.layers.95.fc2.weight", + "decoder.layers.95.fc2.bias", + "decoder.layers.95.final_layer_norm.weight", + "decoder.layers.95.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + } +} diff --git a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py index 0494647d7bcc..576daacdb471 100644 --- a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py +++ b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py @@ -1,7 +1,8 @@ import os -import torch from multiprocessing import Pool +import torch + # download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main # you can use whether wget or git lfs @@ -20,14 +21,14 @@ restored = {} for ckpt in ckpts: - for k,v in ckpt.items(): - if(k[0] == 'm'): - k = k[6:] - if(k == "lm_head.weight"): + for k, v in ckpt.items(): + if k[0] == "m": + k = k[6:] + if k == "lm_head.weight": k = "head.dense.weight" - if(k == "decoder.final_layer_norm.weight"): + if k == "decoder.final_layer_norm.weight": k = "decoder.layer_norm.weight" - if(k == "decoder.final_layer_norm.bias"): + if k == "decoder.final_layer_norm.bias": k = "decoder.layer_norm.bias" restored[k] = v restored["decoder.version"] = "0.0" @@ -37,11 +38,11 @@ count = 0 file_count = 1 tmp = {} -for k,v in restored.items(): +for k, v in restored.items(): print(k) tmp[k] = v - count = count + 1 - if(count == split_num): + count = count + 1 + if count == split_num: filename = str(file_count) + "-restored.pt" torch.save(tmp, os.path.join(new_path, filename)) file_count = file_count + 1 @@ -50,6 +51,3 @@ filename = str(file_count) + "-restored.pt" torch.save(tmp, os.path.join(new_path, filename)) - - - diff --git a/examples/tutorial/opt/opt/colossalai_zero.py b/examples/tutorial/opt/opt/colossalai_zero.py index 8fbed6e83d52..75516bba560f 100644 --- a/examples/tutorial/opt/opt/colossalai_zero.py +++ b/examples/tutorial/opt/opt/colossalai_zero.py @@ -4,7 +4,7 @@ # colossalai > 0.2.8 from colossalai.legacy.zero import TensorShardStrategy -zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), - tensor_placement_policy="auto", - reuse_fp16_shard=True), - optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384)) +zero = dict( + model_config=dict(shard_strategy=TensorShardStrategy(), tensor_placement_policy="auto", reuse_fp16_shard=True), + optimizer_config=dict(gpu_margin_mem_ratio=0.8, initial_scale=16384), +) diff --git a/examples/tutorial/opt/opt/context.py b/examples/tutorial/opt/opt/context.py index dfcd3b382d3c..7172408f3cbc 100644 --- a/examples/tutorial/opt/opt/context.py +++ b/examples/tutorial/opt/opt/context.py @@ -4,7 +4,7 @@ from colossalai.legacy.core import global_context as gpc -class barrier_context(): +class barrier_context: """ This context manager is used to allow one process to execute while blocking all other processes in the same process group. This is often useful when downloading is required diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index 8cbf3d2a2850..9bd23ffc8aba 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -86,14 +86,12 @@ def parse_args(): default=None, help="The configuration name of the dataset to use (via the datasets library).", ) - parser.add_argument("--train_file", - type=str, - default=None, - help="A csv or a json file containing the training data.") - parser.add_argument("--validation_file", - type=str, - default=None, - help="A csv or a json file containing the validation data.") + parser.add_argument( + "--train_file", type=str, default=None, help="A csv or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." + ) parser.add_argument( "--validation_split_percentage", default=5, @@ -161,10 +159,9 @@ def parse_args(): help="The scheduler type to use.", choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], ) - parser.add_argument("--num_warmup_steps", - type=int, - default=0, - help="Number of steps for the warmup in the lr scheduler.") + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") parser.add_argument( @@ -178,9 +175,11 @@ def parse_args(): "--block_size", type=int, default=None, - help=("Optional input sequence length after tokenization. The training dataset will be truncated in block of" - " this size for training. Default to the model max input length for single sentence inputs (take into" - " account special tokens)."), + help=( + "Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)." + ), ) parser.add_argument( "--preprocessing_num_workers", @@ -188,17 +187,16 @@ def parse_args(): default=None, help="The number of processes to use for the preprocessing.", ) - parser.add_argument("--overwrite_cache", - type=bool, - default=False, - help="Overwrite the cached training and evaluation sets") - parser.add_argument("--no_keep_linebreaks", - action="store_true", - help="Do not keep line breaks when using TXT files.") + parser.add_argument( + "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files." + ) parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_model_id", - type=str, - help="The name of the repository to keep in sync with the local `output_dir`.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") parser.add_argument( "--checkpointing_steps", @@ -221,13 +219,15 @@ def parse_args(): "--report_to", type=str, default="all", - help=('The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' - ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' - "Only applicable when `--with_tracking` is passed."), + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." + ), ) parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap") - parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu") + parser.add_argument("--init_in_cpu", action="store_true", default=False, help="init training model in cpu") args = parser.parse_args() # Sanity checks @@ -250,6 +250,7 @@ def parse_args(): def colo_memory_cap(size_in_GB): from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device + cuda_capacity = colo_device_memory_capacity(get_current_device()) if size_in_GB * (1024**3) < cuda_capacity: colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity) @@ -257,7 +258,6 @@ def colo_memory_cap(size_in_GB): class DummyDataloader: - def __init__(self, length, batch_size, seq_len, vocab_size): self.length = length self.batch_size = batch_size @@ -380,32 +380,36 @@ def main(): logger.warning("You are instantiating a new config instance from scratch.") logger.info("Model config has been created", ranks=[0]) - if args.model_name_or_path == 'facebook/opt-13b': + if args.model_name_or_path == "facebook/opt-13b": tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path) else: - print(f'load model from {args.model_name_or_path}') + print(f"load model from {args.model_name_or_path}") tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer) logger.info(f"{tokenizer.__class__.__name__} has been created", ranks=[0]) if args.init_in_cpu: - init_dev = torch.device('cpu') + init_dev = torch.device("cpu") else: init_dev = get_current_device() cai_version = colossalai.__version__ - logger.info(f'using Colossal-AI version {cai_version}') + logger.info(f"using Colossal-AI version {cai_version}") # build model if version.parse(cai_version) >= version.parse("0.3.1"): from contextlib import nullcontext from colossalai.lazy import LazyInitContext - ctx = LazyInitContext( - default_device=init_dev - ) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b' else nullcontext() + + ctx = ( + LazyInitContext(default_device=init_dev) + if args.model_name_or_path is None or args.model_name_or_path == "facebook/opt-13b" + else nullcontext() + ) else: from colossalai.zero import ColoInitContext + ctx = ColoInitContext(device=init_dev) - if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': + if args.model_name_or_path is None or args.model_name_or_path == "facebook/opt-13b": # currently, there has a bug in pretrained opt-13b # we can not import it until huggingface fix it logger.info("Train a new model from scratch", ranks=[0]) @@ -414,17 +418,20 @@ def main(): else: logger.info("Finetune a pre-trained model", ranks=[0]) with ctx: - model = OPTForCausalLM.from_pretrained(args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - local_files_only=False) + model = OPTForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + local_files_only=False, + ) # enable graident checkpointing model.gradient_checkpointing_enable() - PLACEMENT_POLICY = 'auto' + PLACEMENT_POLICY = "auto" if version.parse(cai_version) >= version.parse("0.3.1"): from colossalai.zero import GeminiDDP + model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True) elif version.parse(cai_version) > version.parse("0.1.10"): try: @@ -435,16 +442,19 @@ def main(): model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): from colossalai.gemini import ChunkManager, GeminiManager + pg = ProcessGroup() chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) - chunk_manager = ChunkManager(chunk_size, - pg, - enable_distributed_storage=True, - init_device=GeminiManager.get_default_device(PLACEMENT_POLICY)) + chunk_manager = ChunkManager( + chunk_size, + pg, + enable_distributed_storage=True, + init_device=GeminiManager.get_default_device(PLACEMENT_POLICY), + ) gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager) model = ZeroDDP(model, gemini_manager) - logger.info(f'{model.__class__.__name__} has been created', ranks=[0]) + logger.info(f"{model.__class__.__name__} has been created", ranks=[0]) if not args.synthetic: # Preprocessing the datasets. @@ -470,12 +480,15 @@ def tokenize_function(examples): if block_size > 1024: logger.warning( f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --block_size xxx.") + "Picking 1024 instead. You can change that default value by passing --block_size xxx." + ) block_size = 1024 else: if args.block_size > tokenizer.model_max_length: - logger.warning(f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" - f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.") + logger.warning( + f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) block_size = min(args.block_size, tokenizer.model_max_length) # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. @@ -489,8 +502,8 @@ def group_texts(examples): total_length = (total_length // block_size) * block_size # Split by chunks of max_len. result = { - k: [t[i:i + block_size] for i in range(0, total_length, block_size) - ] for k, t in concatenated_examples.items() + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result @@ -520,19 +533,23 @@ def group_texts(examples): # logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # DataLoaders creation: - train_dataloader = get_dataloader(train_dataset, - shuffle=True, - add_sampler=True, - collate_fn=default_data_collator, - batch_size=args.per_device_train_batch_size) - eval_dataloader = DataLoader(eval_dataset, - collate_fn=default_data_collator, - batch_size=args.per_device_eval_batch_size) + train_dataloader = get_dataloader( + train_dataset, + shuffle=True, + add_sampler=True, + collate_fn=default_data_collator, + batch_size=args.per_device_train_batch_size, + ) + eval_dataloader = DataLoader( + eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size + ) else: - train_dataloader = DummyDataloader(30, args.per_device_train_batch_size, config.max_position_embeddings, - config.vocab_size) - eval_dataloader = DummyDataloader(10, args.per_device_train_batch_size, config.max_position_embeddings, - config.vocab_size) + train_dataloader = DummyDataloader( + 30, args.per_device_train_batch_size, config.max_position_embeddings, config.vocab_size + ) + eval_dataloader = DummyDataloader( + 10, args.per_device_train_batch_size, config.max_position_embeddings, config.vocab_size + ) logger.info("Dataloaders have been created", ranks=[0]) # Optimizer @@ -593,7 +610,6 @@ def group_texts(examples): global_step = 0 for epoch in range(starting_epoch, args.num_train_epochs): - if completed_steps >= args.max_train_steps: break @@ -601,7 +617,7 @@ def group_texts(examples): for step, batch in enumerate(train_dataloader): batch = {k: v.cuda() for k, v in batch.items()} outputs = model(use_cache=False, **batch) - loss = outputs['loss'] + loss = outputs["loss"] optimizer.backward(loss) if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: @@ -624,7 +640,7 @@ def group_texts(examples): batch = {k: v.cuda() for k, v in batch.items()} outputs = model(**batch) - loss = outputs['loss'].unsqueeze(0) + loss = outputs["loss"].unsqueeze(0) losses.append(loss) losses = torch.cat(losses) @@ -640,7 +656,7 @@ def group_texts(examples): if args.output_dir is not None: model_state = model.state_dict() if is_main_process: - torch.save(model_state, args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) + torch.save(model_state, args.output_dir + "/epoch_{}_model.pth".format(completed_steps)) dist.barrier() # load_state = torch.load(args.output_dir + '/epoch_{}_model.pth'.format(completed_steps)) # model.load_state_dict(load_state, strict=False) diff --git a/examples/tutorial/sequence_parallel/config.py b/examples/tutorial/sequence_parallel/config.py index 887de7164e12..859f6e25e845 100644 --- a/examples/tutorial/sequence_parallel/config.py +++ b/examples/tutorial/sequence_parallel/config.py @@ -4,7 +4,7 @@ TRAIN_ITERS = 10 DECAY_ITERS = 4 WARMUP_FRACTION = 0.01 -GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU +GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU EVAL_ITERS = 10 EVAL_INTERVAL = 10 LR = 0.0001 @@ -28,8 +28,8 @@ NUM_MICRO_BATCHES = 4 # colossalai config -parallel = dict(pipeline=1, tensor=dict(size=2, mode='sequence')) +parallel = dict(pipeline=1, tensor=dict(size=2, mode="sequence")) fp16 = dict(mode=AMP_TYPE.NAIVE, verbose=True) -gradient_handler = [dict(type='SequenceParallelGradientHandler')] +gradient_handler = [dict(type="SequenceParallelGradientHandler")] diff --git a/examples/tutorial/sequence_parallel/data/__init__.py b/examples/tutorial/sequence_parallel/data/__init__.py index 6fdf07ba5b69..137f3cf0267b 100644 --- a/examples/tutorial/sequence_parallel/data/__init__.py +++ b/examples/tutorial/sequence_parallel/data/__init__.py @@ -15,16 +15,13 @@ def cyclic_iter(iter): yield x -def build_train_valid_test_data_iterators(train_iters, - global_batch_size, - eval_interval, - eval_iters, - dataloader_type='single', - **kwargs): +def build_train_valid_test_data_iterators( + train_iters, global_batch_size, eval_interval, eval_iters, dataloader_type="single", **kwargs +): (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) logger = get_dist_logger() - logger.info('> building train, validation, and test datasets ...', ranks=[0]) + logger.info("> building train, validation, and test datasets ...", ranks=[0]) # Backward compatibility, assume fixed batch size. # if iteration > 0 and consumed_train_samples == 0: @@ -38,29 +35,29 @@ def build_train_valid_test_data_iterators(train_iters, # Data loader only on rank 0 of each model parallel group. if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # Number of train/valid/test samples. train_samples = train_iters * global_batch_size eval_iters_ = (train_iters // eval_interval + 1) * eval_iters test_iters = eval_iters train_val_test_num_samples = [train_samples, eval_iters_ * global_batch_size, test_iters * global_batch_size] - logger.info(' > datasets target sizes (minimum size):') - logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0]) - logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0]) - logger.info(' test: {}'.format(train_val_test_num_samples[2]), ranks=[0]) + logger.info(" > datasets target sizes (minimum size):") + logger.info(" train: {}".format(train_val_test_num_samples[0]), ranks=[0]) + logger.info(" validation: {}".format(train_val_test_num_samples[1]), ranks=[0]) + logger.info(" test: {}".format(train_val_test_num_samples[2]), ranks=[0]) # Build the datasets. train_ds, valid_ds, test_ds = build_train_valid_test_datasets( - train_valid_test_num_samples=train_val_test_num_samples, **kwargs) + train_valid_test_num_samples=train_val_test_num_samples, **kwargs + ) # Build dataloaders. dp_size = gpc.get_world_size(ParallelMode.DATA) - train_dataloader = build_pretraining_data_loader(train_ds, - consumed_samples=0, - micro_batch_size=global_batch_size // dp_size) - valid_dataloader = build_pretraining_data_loader(valid_ds, - consumed_samples=0, - micro_batch_size=global_batch_size // dp_size) + train_dataloader = build_pretraining_data_loader( + train_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size + ) + valid_dataloader = build_pretraining_data_loader( + valid_ds, consumed_samples=0, micro_batch_size=global_batch_size // dp_size + ) test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size // dp_size) # Flags to know if we need to do training/validation/testing. @@ -73,29 +70,26 @@ def build_train_valid_test_data_iterators(train_iters, flags = torch.cuda.LongTensor([0, 0, 0]) # Broadcast num tokens. - torch.distributed.broadcast(flags, - gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], - group=gpc.get_group(ParallelMode.TENSOR)) + torch.distributed.broadcast( + flags, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) + ) # Build iterators. dl_type = dataloader_type - assert dl_type in ['single', 'cyclic'] + assert dl_type in ["single", "cyclic"] if train_dataloader is not None: - train_data_iterator = iter(train_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(train_dataloader)) + train_data_iterator = iter(train_dataloader) if dl_type == "single" else iter(cyclic_iter(train_dataloader)) else: train_data_iterator = None if valid_dataloader is not None: - valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(valid_dataloader)) + valid_data_iterator = iter(valid_dataloader) if dl_type == "single" else iter(cyclic_iter(valid_dataloader)) else: valid_data_iterator = None if test_dataloader is not None: - test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(test_dataloader)) + test_data_iterator = iter(test_dataloader) if dl_type == "single" else iter(cyclic_iter(test_dataloader)) else: test_data_iterator = None diff --git a/examples/tutorial/sequence_parallel/data/bert_helper.py b/examples/tutorial/sequence_parallel/data/bert_helper.py index b65ca1e64f3c..471be19bb123 100644 --- a/examples/tutorial/sequence_parallel/data/bert_helper.py +++ b/examples/tutorial/sequence_parallel/data/bert_helper.py @@ -15,7 +15,7 @@ def _build_key_size_numel_dictionaries(keys, data): if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: offset = 0 for key in keys: - assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM" size = data[key].size() for i, s in enumerate(size): sizes[i + offset] = s @@ -23,9 +23,9 @@ def _build_key_size_numel_dictionaries(keys, data): # Move to GPU and broadcast. sizes_cuda = torch.cuda.LongTensor(sizes) - torch.distributed.broadcast(sizes_cuda, - gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], - group=gpc.get_group(ParallelMode.TENSOR)) + torch.distributed.broadcast( + sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) + ) # Move back to cpu and unpack. sizes_cpu = sizes_cuda.cpu() @@ -73,9 +73,9 @@ def broadcast_data(keys, data, datatype): flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) # Broadcast - torch.distributed.broadcast(flatten_data, - gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], - group=gpc.get_group(ParallelMode.TENSOR)) + torch.distributed.broadcast( + flatten_data, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR) + ) # Unpack output = {} @@ -93,7 +93,7 @@ def get_batch(data_iterator): """Build the batch.""" # Items and their type. - keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] datatype = torch.int64 # Broadcast data. @@ -104,12 +104,12 @@ def get_batch(data_iterator): data_b = broadcast_data(keys, data, datatype) # Unpack. - tokens = data_b['text'].long() - types = data_b['types'].long() - sentence_order = data_b['is_random'].long() - loss_mask = data_b['loss_mask'].float() - lm_labels = data_b['labels'].long() - padding_mask = data_b['padding_mask'].long() + tokens = data_b["text"].long() + types = data_b["types"].long() + sentence_order = data_b["is_random"].long() + loss_mask = data_b["loss_mask"].float() + lm_labels = data_b["labels"].long() + padding_mask = data_b["padding_mask"].long() return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask @@ -118,7 +118,7 @@ def get_batch_for_sequence_parallel(data_iterator): """Build the batch.""" # Items and their type. - keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] + keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] datatype = torch.int64 # Broadcast data. @@ -134,24 +134,23 @@ def get_batch_for_sequence_parallel(data_iterator): global_rank = torch.distributed.get_rank() local_world_size = 1 if not gpc.is_initialized(ParallelMode.TENSOR) else gpc.get_world_size(ParallelMode.TENSOR) local_rank = global_rank % local_world_size - seq_length = data_b['text'].size(1) + seq_length = data_b["text"].size(1) sub_seq_length = seq_length // local_world_size sub_seq_start = local_rank * sub_seq_length sub_seq_end = (local_rank + 1) * sub_seq_length # # # Unpack. - tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long() - types = data_b['types'][:, sub_seq_start:sub_seq_end].long() - sentence_order = data_b['is_random'].long() - loss_mask = data_b['loss_mask'][:, sub_seq_start:sub_seq_end].float() - lm_labels = data_b['labels'][:, sub_seq_start:sub_seq_end].long() - padding_mask = data_b['padding_mask'].long() + tokens = data_b["text"][:, sub_seq_start:sub_seq_end].long() + types = data_b["types"][:, sub_seq_start:sub_seq_end].long() + sentence_order = data_b["is_random"].long() + loss_mask = data_b["loss_mask"][:, sub_seq_start:sub_seq_end].float() + lm_labels = data_b["labels"][:, sub_seq_start:sub_seq_end].long() + padding_mask = data_b["padding_mask"].long() return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask class SequenceParallelDataIterator: - def __init__(self, data_iter): self.data_iter = data_iter diff --git a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py index 70c1269122dc..afab202e0927 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/bert_dataset.py @@ -41,10 +41,19 @@ class BertDataset(Dataset): - - def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_samples, masked_lm_prob, max_seq_length, - short_seq_prob, seed, binary_head): - + def __init__( + self, + name, + indexed_dataset, + data_prefix, + num_epochs, + max_num_samples, + masked_lm_prob, + max_seq_length, + short_seq_prob, + seed, + binary_head, + ): # Params to store. self.name = name self.seed = seed @@ -61,11 +70,12 @@ def __init__(self, name, indexed_dataset, data_prefix, num_epochs, max_num_sampl data_prefix, num_epochs, max_num_samples, - self.max_seq_length - 3, # account for added tokens, + self.max_seq_length - 3, # account for added tokens, short_seq_prob, self.seed, self.name, - self.binary_head) + self.binary_head, + ) # Vocab stuff. tokenizer = get_tokenizer() @@ -89,7 +99,7 @@ def __getitem__(self, idx): return build_training_sample( sample, seq_length, - self.max_seq_length, # needed for padding + self.max_seq_length, # needed for padding self.vocab_id_list, self.vocab_id_to_token_dict, self.cls_id, @@ -98,37 +108,39 @@ def __getitem__(self, idx): self.pad_id, self.masked_lm_prob, np_rng, - self.binary_head) + self.binary_head, + ) -def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, - seed, name, binary_head): +def get_samples_mapping_( + indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, name, binary_head +): logger = get_dist_logger() if not num_epochs: if not max_num_samples: - raise ValueError("Need to specify either max_num_samples " - "or num_epochs") + raise ValueError("Need to specify either max_num_samples " "or num_epochs") num_epochs = np.iinfo(np.int32).max - 1 if not max_num_samples: max_num_samples = np.iinfo(np.int64).max - 1 # Filename of the index mapping indexmap_filename = data_prefix - indexmap_filename += '_{}_indexmap'.format(name) + indexmap_filename += "_{}_indexmap".format(name) if num_epochs != (np.iinfo(np.int32).max - 1): - indexmap_filename += '_{}ep'.format(num_epochs) + indexmap_filename += "_{}ep".format(num_epochs) if max_num_samples != (np.iinfo(np.int64).max - 1): - indexmap_filename += '_{}mns'.format(max_num_samples) - indexmap_filename += '_{}msl'.format(max_seq_length) - indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob) - indexmap_filename += '_{}s'.format(seed) - indexmap_filename += '.npy' + indexmap_filename += "_{}mns".format(max_num_samples) + indexmap_filename += "_{}msl".format(max_seq_length) + indexmap_filename += "_{:0.2f}ssp".format(short_seq_prob) + indexmap_filename += "_{}s".format(seed) + indexmap_filename += ".npy" # Build the indexed mapping if not exist. - if torch.distributed.get_rank() == 0 and \ - not os.path.isfile(indexmap_filename): - print(' > WARNING: could not find index map file {}, building ' - 'the indices on rank 0 ...'.format(indexmap_filename)) + if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename): + print( + " > WARNING: could not find index map file {}, building " + "the indices on rank 0 ...".format(indexmap_filename) + ) # Make sure the types match the helpers input types. assert indexed_dataset.doc_idx.dtype == np.int64 @@ -137,18 +149,27 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl # Build samples mapping verbose = torch.distributed.get_rank() == 0 start_time = time.time() - logger.info('\n > building samples index mapping for {} ...'.format(name), ranks=[0]) + logger.info("\n > building samples index mapping for {} ...".format(name), ranks=[0]) # First compile and then import. - samples_mapping = helpers.build_mapping(indexed_dataset.doc_idx, indexed_dataset.sizes, num_epochs, - max_num_samples, max_seq_length, short_seq_prob, seed, verbose, - 2 if binary_head else 1) - logger.info('\n > done building samples index maping', ranks=[0]) + samples_mapping = helpers.build_mapping( + indexed_dataset.doc_idx, + indexed_dataset.sizes, + num_epochs, + max_num_samples, + max_seq_length, + short_seq_prob, + seed, + verbose, + 2 if binary_head else 1, + ) + logger.info("\n > done building samples index maping", ranks=[0]) np.save(indexmap_filename, samples_mapping, allow_pickle=True) - logger.info('\n > saved the index mapping in {}'.format(indexmap_filename), ranks=[0]) + logger.info("\n > saved the index mapping in {}".format(indexmap_filename), ranks=[0]) # Make sure all the ranks have built the mapping - logger.info('\n > elapsed time to build and save samples mapping ' - '(seconds): {:4f}'.format(time.time() - start_time), - ranks=[0]) + logger.info( + "\n > elapsed time to build and save samples mapping " "(seconds): {:4f}".format(time.time() - start_time), + ranks=[0], + ) # This should be a barrier but nccl barrier assumes # device_index=rank which is not the case for model # parallel case @@ -156,22 +177,38 @@ def get_samples_mapping_(indexed_dataset, data_prefix, num_epochs, max_num_sampl torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA)) if gpc.is_initialized(ParallelMode.PIPELINE): torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE)) - assert counts[0].item() == (torch.distributed.get_world_size() // - torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE))) + assert counts[0].item() == ( + torch.distributed.get_world_size() + // torch.distributed.get_world_size(group=gpc.get_group(ParallelMode.SEQUENCE)) + ) # Load indexed dataset. start_time = time.time() - samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r') - logger.info('\n > loading indexed mapping from {}'.format(indexmap_filename) + - '\n loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time) + - '\n total number of samples: {}'.format(samples_mapping.shape[0]), - ranks=[0]) + samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode="r") + logger.info( + "\n > loading indexed mapping from {}".format(indexmap_filename) + + "\n loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + + "\n total number of samples: {}".format(samples_mapping.shape[0]), + ranks=[0], + ) return samples_mapping -def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_list, vocab_id_to_token_dict, cls_id, - sep_id, mask_id, pad_id, masked_lm_prob, np_rng, binary_head): +def build_training_sample( + sample, + target_seq_length, + max_seq_length, + vocab_id_list, + vocab_id_to_token_dict, + cls_id, + sep_id, + mask_id, + pad_id, + masked_lm_prob, + np_rng, + binary_head, +): """Build training sample. Arguments: @@ -215,22 +252,30 @@ def build_training_sample(sample, target_seq_length, max_seq_length, vocab_id_li # Masking. max_predictions_per_seq = masked_lm_prob * max_num_tokens - (tokens, masked_positions, masked_labels, - _) = create_masked_lm_predictions(tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, cls_id, sep_id, - mask_id, max_predictions_per_seq, np_rng) + (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions( + tokens, + vocab_id_list, + vocab_id_to_token_dict, + masked_lm_prob, + cls_id, + sep_id, + mask_id, + max_predictions_per_seq, + np_rng, + ) # Padding. - tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ - = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length) + tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np = pad_and_convert_to_numpy( + tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length + ) train_sample = { - 'text': tokens_np, - 'types': tokentypes_np, - 'labels': labels_np, - 'is_random': int(is_next_random), - 'loss_mask': loss_mask_np, - 'padding_mask': padding_mask_np, - 'truncated': int(truncated) + "text": tokens_np, + "types": tokentypes_np, + "labels": labels_np, + "is_random": int(is_next_random), + "loss_mask": loss_mask_np, + "padding_mask": padding_mask_np, + "truncated": int(truncated), } return train_sample diff --git a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py index 6a06c869d8c8..1fa9c85fce0a 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py @@ -22,9 +22,7 @@ class BlendableDataset(torch.utils.data.Dataset): - def __init__(self, datasets, weights): - self.datasets = datasets num_datasets = len(datasets) assert num_datasets == len(weights) @@ -46,12 +44,16 @@ def __init__(self, datasets, weights): self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) from . import helpers - helpers.build_blending_indices(self.dataset_index, - self.dataset_sample_index, - weights, num_datasets, self.size, - torch.distributed.get_rank() == 0) - print('> elapsed time for building blendable dataset indices: ' - '{:.2f} (sec)'.format(time.time() - start_time)) + + helpers.build_blending_indices( + self.dataset_index, + self.dataset_sample_index, + weights, + num_datasets, + self.size, + torch.distributed.get_rank() == 0, + ) + print("> elapsed time for building blendable dataset indices: " "{:.2f} (sec)".format(time.time() - start_time)) def __len__(self): return self.size diff --git a/examples/tutorial/sequence_parallel/data/datasets/builder.py b/examples/tutorial/sequence_parallel/data/datasets/builder.py index 6106f833b462..edf4c3d70cbf 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/builder.py +++ b/examples/tutorial/sequence_parallel/data/datasets/builder.py @@ -1,29 +1,34 @@ +from colossalai.logging import get_dist_logger + +from .bert_dataset import BertDataset from .blendable_dataset import BlendableDataset from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ -from .bert_dataset import BertDataset -from colossalai.logging import get_dist_logger -DSET_TYPE_BERT = 'standard_bert' -DSET_TYPE_ICT = 'ict' -DSET_TYPE_T5 = 't5' +DSET_TYPE_BERT = "standard_bert" +DSET_TYPE_ICT = "ict" +DSET_TYPE_T5 = "t5" DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] -def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): - +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): if dataset_type not in DSET_TYPES: raise ValueError("Invalid dataset_type: ", dataset_type) # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) # Get start and end indices of train/valid/train into doc-idx # Note that doc-idx is designed to be num-docs + 1 so we can @@ -34,22 +39,25 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, logger = get_dist_logger() # Print stats about the splits. - logger.info('\n > dataset split:', ranks=[0]) + logger.info("\n > dataset split:", ranks=[0]) def print_split_stats(name, index): start_index = indexed_dataset.doc_idx[splits[index]] end_index = indexed_dataset.doc_idx[splits[index + 1]] - logger.info('\n {}:'.format(name) + - '\n document indices in [{}, {}) total of {} documents'.format( - splits[index], splits[index + 1], - splits[index + 1] - splits[index]) + - '\n sentence indices in [{}, {}) total of {} sentences'.format( - start_index, end_index, - end_index - start_index), - ranks=[0]) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + logger.info( + "\n {}:".format(name) + + "\n document indices in [{}, {}) total of {} documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + + "\n sentence indices in [{}, {}) total of {} sentences".format( + start_index, end_index, end_index - start_index + ), + ranks=[0], + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): dataset = None @@ -80,44 +88,53 @@ def build_dataset(index, name): masked_lm_prob=masked_lm_prob, short_seq_prob=short_seq_prob, binary_head=binary_head, - **kwargs + **kwargs, ) # Set the original pointer so dataset remains the main dataset. indexed_dataset.set_doc_idx(doc_idx_ptr) # Checks. assert indexed_dataset.doc_idx[0] == 0 - assert indexed_dataset.doc_idx.shape[0] == \ - (total_num_of_documents + 1) + assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) return dataset - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") return (train_dataset, valid_dataset, test_dataset) -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): - +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): if len(data_prefix) == 1: - return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, - skip_warmup, - binary_head, - dataset_type=dataset_type) + return _build_train_valid_test_datasets( + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) # Blending dataset. # Parse the values. - output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) prefixes, weights, datasets_train_valid_test_num_samples = output # Build individual datasets. @@ -126,10 +143,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, test_datasets = [] for i in range(len(prefixes)): train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( - prefixes[i], data_impl, splits_string, + prefixes[i], + data_impl, + splits_string, datasets_train_valid_test_num_samples[i], - max_seq_length, masked_lm_prob, short_seq_prob, - seed, skip_warmup, binary_head, dataset_type=dataset_type) + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) if train_ds: train_datasets.append(train_ds) if valid_ds: @@ -148,5 +173,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, if test_datasets: blending_test_dataset = BlendableDataset(test_datasets, weights) - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) diff --git a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py index b9c197c95ae3..8ba598529ebc 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py +++ b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py @@ -14,7 +14,6 @@ # limitations under the License. """Dataloaders.""" -import random import torch @@ -22,61 +21,60 @@ from colossalai.legacy.core import global_context as gpc -def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0): +def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type="single", num_workers=0): """Build dataloader given an input dataset.""" if dataset is None: return None # Megatron sampler - if dataloader_type == 'single': - batch_sampler = MegatronPretrainingSampler(total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=micro_batch_size, - data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), - data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) - elif dataloader_type == 'cyclic': - batch_sampler = MegatronPretrainingRandomSampler(total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=micro_batch_size, - data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), - data_parallel_size=gpc.get_world_size(ParallelMode.DATA)) + if dataloader_type == "single": + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA), + ) + elif dataloader_type == "cyclic": + batch_sampler = MegatronPretrainingRandomSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=gpc.get_local_rank(ParallelMode.DATA), + data_parallel_size=gpc.get_world_size(ParallelMode.DATA), + ) else: - raise Exception('{} dataloader type is not supported.'.format(dataloader_type)) + raise Exception("{} dataloader type is not supported.".format(dataloader_type)) # Torch dataloader. return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=num_workers, pin_memory=True) class MegatronPretrainingSampler: - - def __init__(self, - total_samples, - consumed_samples, - micro_batch_size, - data_parallel_rank, - data_parallel_size, - drop_last=True): + def __init__( + self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last=True + ): # Keep a copy of input params for later use. self.total_samples = total_samples self.consumed_samples = consumed_samples self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank - self.micro_batch_times_data_parallel_size = \ - self.micro_batch_size * data_parallel_size + self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size self.drop_last = drop_last # Sanity checks. - assert self.total_samples > 0, \ - 'no sample to consume: {}'.format(self.total_samples) - assert self.consumed_samples < self.total_samples, \ - 'no samples left to consume: {}, {}'.format(self.consumed_samples, - self.total_samples) + assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) + assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format( + self.consumed_samples, self.total_samples + ) assert self.micro_batch_size > 0 assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, \ - 'data_parallel_rank should be smaller than data size: {}, ' \ - '{}'.format(self.data_parallel_rank, data_parallel_size) + assert ( + self.data_parallel_rank < data_parallel_size + ), "data_parallel_rank should be smaller than data size: {}, " "{}".format( + self.data_parallel_rank, data_parallel_size + ) def __len__(self): return self.total_samples @@ -103,7 +101,6 @@ def __iter__(self): class MegatronPretrainingRandomSampler: - def __init__(self, total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size): # Keep a copy of input params for later use. self.total_samples = total_samples @@ -111,19 +108,18 @@ def __init__(self, total_samples, consumed_samples, micro_batch_size, data_paral self.micro_batch_size = micro_batch_size self.data_parallel_rank = data_parallel_rank self.data_parallel_size = data_parallel_size - self.micro_batch_times_data_parallel_size = \ - self.micro_batch_size * data_parallel_size - self.last_batch_size = \ - self.total_samples % self.micro_batch_times_data_parallel_size + self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size + self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size # Sanity checks. - assert self.total_samples > 0, \ - 'no sample to consume: {}'.format(self.total_samples) + assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples) assert self.micro_batch_size > 0 assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, \ - 'data_parallel_rank should be smaller than data size: {}, ' \ - '{}'.format(self.data_parallel_rank, data_parallel_size) + assert ( + self.data_parallel_rank < data_parallel_size + ), "data_parallel_rank should be smaller than data size: {}, " "{}".format( + self.data_parallel_rank, data_parallel_size + ) def __len__(self): return self.total_samples @@ -135,8 +131,7 @@ def __iter__(self): assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 # data sharding and random sampling - bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \ - * self.micro_batch_size + bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size bucket_offset = current_epoch_samples // self.data_parallel_size start_idx = self.data_parallel_rank * bucket_size diff --git a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py index cf4e4763fc10..3e197ff96c0c 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py +++ b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py @@ -18,32 +18,33 @@ # https://github.com/google-research/albert/blob/master/create_pretraining_data.py # with some modifications. +import collections import math import time -import collections -from colossalai.logging import get_dist_logger + import numpy as np + +from colossalai.logging import get_dist_logger + from .blendable_dataset import BlendableDataset from .indexed_dataset import make_dataset as make_indexed_dataset -DSET_TYPE_STD = 'standard_bert' -DSET_TYPE_ICT = 'ict' +DSET_TYPE_STD = "standard_bert" +DSET_TYPE_ICT = "ict" DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD] -def get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples): - +def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples): # The data prefix should be in the format of: # weight-1, data-prefix-1, weight-2, data-prefix-2, .. assert len(data_prefix) % 2 == 0 num_datasets = len(data_prefix) // 2 - weights = [0]*num_datasets - prefixes = [0]*num_datasets + weights = [0] * num_datasets + prefixes = [0] * num_datasets for i in range(num_datasets): - weights[i] = float(data_prefix[2*i]) - prefixes[i] = (data_prefix[2*i+1]).strip() + weights[i] = float(data_prefix[2 * i]) + prefixes[i] = (data_prefix[2 * i + 1]).strip() # Normalize weights weight_sum = 0.0 for weight in weights: @@ -57,8 +58,8 @@ def get_datasets_weights_and_num_samples(data_prefix, datasets_train_valid_test_num_samples = [] for weight in weights: datasets_train_valid_test_num_samples.append( - [int(math.ceil(val * weight * 1.005)) - for val in train_valid_test_num_samples]) + [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples] + ) return prefixes, weights, datasets_train_valid_test_num_samples @@ -68,11 +69,13 @@ def compile_helper(): is invoked on a single process.""" import os import subprocess + path = os.path.abspath(os.path.dirname(__file__)) - ret = subprocess.run(['make', '-C', path]) + ret = subprocess.run(["make", "-C", path]) if ret.returncode != 0: print("Making C++ dataset helpers module failed, exiting.") import sys + sys.exit(1) @@ -82,7 +85,7 @@ def get_a_and_b_segments(sample, np_rng): # Number of sentences in the sample. n_sentences = len(sample) # Make sure we always have two sentences. - assert n_sentences > 1, 'make sure each sample has at least two sentences.' + assert n_sentences > 1, "make sure each sample has at least two sentences." # First part: # `a_end` is how many sentences go into the `A`. @@ -110,7 +113,7 @@ def get_a_and_b_segments(sample, np_rng): def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): """Truncates a pair of sequences to a maximum sequence length.""" - #print(len_a, len_b, max_num_tokens) + # print(len_a, len_b, max_num_tokens) assert len_a > 0 if len_a + len_b <= max_num_tokens: return False @@ -155,8 +158,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): return tokens, tokentypes -MaskedLmInstance = collections.namedtuple("MaskedLmInstance", - ["index", "label"]) +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) def is_start_piece(piece): @@ -168,16 +170,21 @@ def is_start_piece(piece): return not piece.startswith("##") -def create_masked_lm_predictions(tokens, - vocab_id_list, vocab_id_to_token_dict, - masked_lm_prob, - cls_id, sep_id, mask_id, - max_predictions_per_seq, - np_rng, - max_ngrams=3, - do_whole_word_mask=True, - favor_longer_ngram=False, - do_permutation=False): +def create_masked_lm_predictions( + tokens, + vocab_id_list, + vocab_id_to_token_dict, + masked_lm_prob, + cls_id, + sep_id, + mask_id, + max_predictions_per_seq, + np_rng, + max_ngrams=3, + do_whole_word_mask=True, + favor_longer_ngram=False, + do_permutation=False, +): """Creates the predictions for the masked LM objective. Note: Tokens here are vocab ids and not text tokens.""" @@ -187,7 +194,7 @@ def create_masked_lm_predictions(tokens, # on-the-fly whole word masking is possible. token_boundary = [0] * len(tokens) - for (i, token) in enumerate(tokens): + for i, token in enumerate(tokens): if token == cls_id or token == sep_id: token_boundary[i] = 1 continue @@ -197,8 +204,7 @@ def create_masked_lm_predictions(tokens, # Note that Whole Word Masking does *not* change the training code # at all -- we still predict each WordPiece independently, softmaxed # over the entire vocabulary. - if (do_whole_word_mask and len(cand_indexes) >= 1 and - not is_start_piece(vocab_id_to_token_dict[token])): + if do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]): cand_indexes[-1].append(i) else: cand_indexes.append([i]) @@ -211,16 +217,14 @@ def create_masked_lm_predictions(tokens, masked_lm_labels = [] if masked_lm_prob == 0: - return (output_tokens, masked_lm_positions, - masked_lm_labels, token_boundary) + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) - num_to_predict = min(max_predictions_per_seq, - max(1, int(round(len(tokens) * masked_lm_prob)))) + num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))) # Note(mingdachen): # By default, we set the probabilities to favor shorter ngram sequences. ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) - pvals = 1. / np.arange(1, max_ngrams + 1) + pvals = 1.0 / np.arange(1, max_ngrams + 1) pvals /= pvals.sum(keepdims=True) if favor_longer_ngram: @@ -230,7 +234,7 @@ def create_masked_lm_predictions(tokens, for idx in range(len(cand_indexes)): ngram_index = [] for n in ngrams: - ngram_index.append(cand_indexes[idx:idx + n]) + ngram_index.append(cand_indexes[idx : idx + n]) ngram_indexes.append(ngram_index) np_rng.shuffle(ngram_indexes) @@ -249,9 +253,10 @@ def create_masked_lm_predictions(tokens, if index in covered_indexes: continue - n = np_rng.choice(ngrams[:len(cand_index_set)], - p=pvals[:len(cand_index_set)] / - pvals[:len(cand_index_set)].sum(keepdims=True)) + n = np_rng.choice( + ngrams[: len(cand_index_set)], + p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True), + ) index_set = sum(cand_index_set[n - 1], []) n -= 1 # Note(mingdachen): @@ -309,9 +314,10 @@ def create_masked_lm_predictions(tokens, if index in covered_indexes or index in select_indexes: continue - n = np.random.choice(ngrams[:len(cand_index_set)], - p=pvals[:len(cand_index_set)] / - pvals[:len(cand_index_set)].sum(keepdims=True)) + n = np.random.choice( + ngrams[: len(cand_index_set)], + p=pvals[: len(cand_index_set)] / pvals[: len(cand_index_set)].sum(keepdims=True), + ) index_set = sum(cand_index_set[n - 1], []) n -= 1 @@ -353,8 +359,7 @@ def create_masked_lm_predictions(tokens, return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) -def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length): +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length): """Pad sequences and convert them to numpy.""" # Some checks. @@ -370,8 +375,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) # Padding mask. - padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, - dtype=np.int64) + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64) # Lables and loss mask. labels = [-1] * max_seq_length @@ -386,26 +390,36 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): - +def build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): if len(data_prefix) == 1: - return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, - skip_warmup, - binary_head, - dataset_type=dataset_type) + return _build_train_valid_test_datasets( + data_prefix[0], + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) # Blending dataset. # Parse the values. - output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) prefixes, weights, datasets_train_valid_test_num_samples = output # Build individual datasets. @@ -414,10 +428,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, test_datasets = [] for i in range(len(prefixes)): train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( - prefixes[i], data_impl, splits_string, + prefixes[i], + data_impl, + splits_string, datasets_train_valid_test_num_samples[i], - max_seq_length, masked_lm_prob, short_seq_prob, - seed, skip_warmup, binary_head, dataset_type=dataset_type) + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type, + ) if train_ds: train_datasets.append(train_ds) if valid_ds: @@ -436,31 +458,33 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, if test_datasets: blending_test_dataset = BlendableDataset(test_datasets, weights) - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) - - -def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, - train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, - binary_head, - dataset_type='standard_bert'): + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) + + +def _build_train_valid_test_datasets( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type="standard_bert", +): logger = get_dist_logger() if dataset_type not in DSET_TYPES: raise ValueError("Invalid dataset_type: ", dataset_type) # Indexed dataset. - indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) if dataset_type == DSET_TYPE_ICT: args = get_args() - title_dataset = get_indexed_dataset_(args.titles_data_path, - data_impl, - skip_warmup) + title_dataset = get_indexed_dataset_(args.titles_data_path, data_impl, skip_warmup) # Get start and end indices of train/valid/train into doc-idx # Note that doc-idx is designed to be num-docs + 1 so we can @@ -469,27 +493,29 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, splits = get_train_valid_test_split_(splits_string, total_num_of_documents) # Print stats about the splits. - logger.info('\n > dataset split:') + logger.info("\n > dataset split:") def print_split_stats(name, index): start_index = indexed_dataset.doc_idx[splits[index]] end_index = indexed_dataset.doc_idx[splits[index + 1]] - logger.info('\n {}:'.format(name) + - '\n document indices in [{}, {}) total of {} documents'.format( - splits[index], - splits[index + 1], - splits[index + 1] - splits[index]) + - '\n sentence indices in [{}, {}) total of {} sentences'.format( - start_index, - end_index, - end_index - start_index), - ranks=[0]) - print_split_stats('train', 0) - print_split_stats('validation', 1) - print_split_stats('test', 2) + logger.info( + "\n {}:".format(name) + + "\n document indices in [{}, {}) total of {} documents".format( + splits[index], splits[index + 1], splits[index + 1] - splits[index] + ) + + "\n sentence indices in [{}, {}) total of {} sentences".format( + start_index, end_index, end_index - start_index + ), + ranks=[0], + ) + + print_split_stats("train", 0) + print_split_stats("validation", 1) + print_split_stats("test", 2) def build_dataset(index, name): from .bert_dataset import BertDataset + dataset = None if splits[index + 1] > splits[index]: # Get the pointer to the original doc-idx so we can set it later. @@ -508,7 +534,7 @@ def build_dataset(index, name): max_num_samples=train_valid_test_num_samples[index], max_seq_length=max_seq_length, seed=seed, - binary_head=binary_head + binary_head=binary_head, ) if dataset_type == DSET_TYPE_ICT: @@ -518,27 +544,26 @@ def build_dataset(index, name): title_dataset=title_dataset, query_in_block_prob=args.query_in_block_prob, use_one_sent_docs=args.use_one_sent_docs, - **kwargs + **kwargs, ) else: dataset = BertDataset( indexed_dataset=indexed_dataset, masked_lm_prob=masked_lm_prob, short_seq_prob=short_seq_prob, - **kwargs + **kwargs, ) # Set the original pointer so dataset remains the main dataset. indexed_dataset.set_doc_idx(doc_idx_ptr) # Checks. assert indexed_dataset.doc_idx[0] == 0 - assert indexed_dataset.doc_idx.shape[0] == \ - (total_num_of_documents + 1) + assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1) return dataset - train_dataset = build_dataset(0, 'train') - valid_dataset = build_dataset(1, 'valid') - test_dataset = build_dataset(2, 'test') + train_dataset = build_dataset(0, "train") + valid_dataset = build_dataset(1, "valid") + test_dataset = build_dataset(2, "test") return (train_dataset, valid_dataset, test_dataset) @@ -546,44 +571,41 @@ def build_dataset(index, name): def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): logger = get_dist_logger() start_time = time.time() - indexed_dataset = make_indexed_dataset(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] - logger.info('\n > building dataset index ...', ranks=[0]) - logger.info('\n > finished creating indexed dataset in {:4f} ' - 'seconds'.format(time.time() - start_time), ranks=[0]) - logger.info('\n > indexed dataset stats:' + - '\n number of documents: {}'.format( - indexed_dataset.doc_idx.shape[0] - 1) + - '\n number of sentences: {}'.format( - indexed_dataset.sizes.shape[0]), - ranks=[0] - ) + logger.info("\n > building dataset index ...", ranks=[0]) + logger.info( + "\n > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - start_time), ranks=[0] + ) + logger.info( + "\n > indexed dataset stats:" + + "\n number of documents: {}".format(indexed_dataset.doc_idx.shape[0] - 1) + + "\n number of sentences: {}".format(indexed_dataset.sizes.shape[0]), + ranks=[0], + ) return indexed_dataset def get_train_valid_test_split_(splits_string, size): - """ Get dataset splits from comma or '/' separated string list.""" + """Get dataset splits from comma or '/' separated string list.""" splits = [] - if splits_string.find(',') != -1: - splits = [float(s) for s in splits_string.split(',')] - elif splits_string.find('/') != -1: - splits = [float(s) for s in splits_string.split('/')] + if splits_string.find(",") != -1: + splits = [float(s) for s in splits_string.split(",")] + elif splits_string.find("/") != -1: + splits = [float(s) for s in splits_string.split("/")] else: splits = [float(splits_string)] while len(splits) < 3: - splits.append(0.) + splits.append(0.0) splits = splits[:3] splits_sum = sum(splits) assert splits_sum > 0.0 splits = [split / splits_sum for split in splits] splits_index = [0] for index, split in enumerate(splits): - splits_index.append(splits_index[index] + - int(round(split * float(size)))) + splits_index.append(splits_index[index] + int(round(split * float(size)))) diff = splits_index[-1] - size for index in range(1, len(splits_index)): splits_index[index] -= diff diff --git a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp index e45926a97696..52977e63181f 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp +++ b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp @@ -15,29 +15,28 @@ limitations under the License. */ - /* Helper methods for fast index mapping builds */ +#include +#include +#include + #include #include #include -#include -#include -#include -#include #include +#include namespace py = pybind11; using namespace std; const int32_t LONG_SENTENCE_LEN = 512; - void build_blending_indices(py::array_t& dataset_index, - py::array_t& dataset_sample_index, - const py::array_t& weights, - const int32_t num_datasets, - const int64_t size, const bool verbose) { + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, const int64_t size, + const bool verbose) { /* Given multiple datasets and a weighting array, build samples such that it follows those wieghts.*/ @@ -52,24 +51,23 @@ void build_blending_indices(py::array_t& dataset_index, // Initialize buffer for number of samples used for each dataset. int64_t current_samples[num_datasets]; - for(int64_t i = 0; i < num_datasets; ++i) { + for (int64_t i = 0; i < num_datasets; ++i) { current_samples[i] = 0; } // For each sample: - for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { - + for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { // Determine where the max error in sampling is happening. auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); int64_t max_error_index = 0; double max_error = weights_ptr[0] * sample_idx_double - - static_cast(current_samples[0]); + static_cast(current_samples[0]); for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { double error = weights_ptr[dataset_idx] * sample_idx_double - - static_cast(current_samples[dataset_idx]); + static_cast(current_samples[dataset_idx]); if (error > max_error) { - max_error = error; - max_error_index = dataset_idx; + max_error = error; + max_error_index = dataset_idx; } } @@ -79,7 +77,6 @@ void build_blending_indices(py::array_t& dataset_index, // Update the total samples. current_samples[max_error_index] += 1; - } // print info @@ -87,631 +84,607 @@ void build_blending_indices(py::array_t& dataset_index, std::cout << " > sample ratios:" << std::endl; for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { auto ratio = static_cast(current_samples[dataset_idx]) / - static_cast(size); - std::cout << " dataset " << dataset_idx << ", input: " << - weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + static_cast(size); + std::cout << " dataset " << dataset_idx + << ", input: " << weights_ptr[dataset_idx] + << ", achieved: " << ratio << std::endl; } } - } - py::array build_sample_idx(const py::array_t& sizes_, - const py::array_t& doc_idx_, - const int32_t seq_length, - const int32_t num_epochs, - const int64_t tokens_per_epoch) { - /* Sample index (sample_idx) is used for gpt2 like dataset for which - the documents are flattened and the samples are built based on this - 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] - where [..., 0] contains the index into `doc_idx` and [..., 1] is the - starting offset in that document.*/ - - // Consistency checks. - assert(seq_length > 1); - assert(num_epochs > 0); - assert(tokens_per_epoch > 1); - - // Remove bound checks. - auto sizes = sizes_.unchecked<1>(); - auto doc_idx = doc_idx_.unchecked<1>(); - - // Mapping and it's length (1D). - int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; - int32_t* sample_idx = new int32_t[2*(num_samples+1)]; - - cout << " using:" << endl << std::flush; - cout << " number of documents: " << - doc_idx_.shape(0) / num_epochs << endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " sequence length: " << seq_length << - endl << std::flush; - cout << " total number of samples: " << num_samples << - endl << std::flush; - - // Index into sample_idx. - int64_t sample_index = 0; - // Index into doc_idx. - int64_t doc_idx_index = 0; - // Begining offset for each document. - int32_t doc_offset = 0; - // Start with first document and no offset. + const py::array_t& doc_idx_, + const int32_t seq_length, const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2 * (num_samples + 1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << doc_idx_.shape(0) / num_epochs + << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " sequence length: " << seq_length << endl + << std::flush; + cout << " total number of samples: " << num_samples << endl + << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. sample_idx[2 * sample_index] = doc_idx_index; sample_idx[2 * sample_index + 1] = doc_offset; ++sample_index; + } - while (sample_index <= num_samples) { - // Start with a fresh sequence. - int32_t remaining_seq_length = seq_length + 1; - while (remaining_seq_length != 0) { - // Get the document length. - auto doc_id = doc_idx[doc_idx_index]; - auto doc_length = sizes[doc_id] - doc_offset; - // And add it to the current sequence. - remaining_seq_length -= doc_length; - // If we have more than a full sequence, adjust offset and set - // remaining length to zero so we return from the while loop. - // Note that -1 here is for the same reason we have -1 in - // `_num_epochs` calculations. - if (remaining_seq_length <= 0) { - doc_offset += (remaining_seq_length + doc_length - 1); - remaining_seq_length = 0; - } else { - // Otherwise, start from the begining of the next document. - ++doc_idx_index; - doc_offset = 0; - } - } - // Record the sequence. - sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; - } - - // Method to deallocate memory. - py::capsule free_when_done(sample_idx, [](void *mem_) { - int32_t *mem = reinterpret_cast(mem_); - delete[] mem; - }); - - // Return the numpy array. - const auto byte_size = sizeof(int32_t); - return py::array(std::vector{num_samples+1, 2}, // shape - {2*byte_size, byte_size}, // C-style contiguous strides - sample_idx, // the data pointer - free_when_done); // numpy array references - + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void* mem_) { + int32_t* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples + 1, 2}, // shape + {2 * byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references } - inline int32_t get_target_sample_len(const int32_t short_seq_ratio, - const int32_t max_length, - std::mt19937& rand32_gen) { - /* Training sample length. */ - if (short_seq_ratio == 0) { - return max_length; - } - const auto random_number = rand32_gen(); - if ((random_number % short_seq_ratio) == 0) { - return 2 + random_number % (max_length - 1); - } + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; } - -template +template py::array build_mapping_impl(const py::array_t& docs_, const py::array_t& sizes_, const int32_t num_epochs, const uint64_t max_num_samples, const int32_t max_seq_length, - const double short_seq_prob, - const int32_t seed, - const bool verbose, - const int32_t min_num_sent) { - /* Build a mapping of (start-index, end-index, sequence-length) where - start and end index are the indices of the sentences in the sample - and sequence-length is the target sequence length. - */ - - // Consistency checks. - assert(num_epochs > 0); - assert(max_seq_length > 1); - assert(short_seq_prob >= 0.0); - assert(short_seq_prob <= 1.0); - assert(seed > 0); - - // Remove bound checks. - auto docs = docs_.unchecked<1>(); - auto sizes = sizes_.unchecked<1>(); - - // For efficiency, convert probability to ratio. Note: rand() generates int. - int32_t short_seq_ratio = 0; - if (short_seq_prob > 0) { - short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); - } + const double short_seq_prob, const int32_t seed, + const bool verbose, const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } - if (verbose) { - const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << - endl << std::flush; - cout << " sentences range: [" << sent_start_index << - ", " << sent_end_index << ")" << endl << std::flush; - cout << " total number of sentences: " << num_sentences << - endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " maximum number of samples: " << max_num_samples << - endl << std::flush; - cout << " maximum sequence length: " << max_seq_length << - endl << std::flush; - cout << " short sequence probability: " << short_seq_prob << - endl << std::flush; - cout << " short sequence ration (1/prob): " << short_seq_ratio << - endl << std::flush; - cout << " seed: " << seed << endl << - std::flush; - } + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 + << endl + << std::flush; + cout << " sentences range: [" << sent_start_index << ", " + << sent_end_index << ")" << endl + << std::flush; + cout << " total number of sentences: " << num_sentences << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " maximum number of samples: " << max_num_samples << endl + << std::flush; + cout << " maximum sequence length: " << max_seq_length << endl + << std::flush; + cout << " short sequence probability: " << short_seq_prob << endl + << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << endl + << std::flush; + cout << " seed: " << seed << endl + << std::flush; + } - // Mapping and it's length (1D). - int64_t num_samples = -1; - DocIdx* maps = NULL; - - // Perform two iterations, in the first iteration get the size - // and allocate memory and in the second iteration populate the map. - bool second = false; - for (int32_t iteration=0; iteration<2; ++iteration) { - - // Set the seed so both iterations produce the same results. - std::mt19937 rand32_gen(seed); - - // Set the flag on second iteration. - second = (iteration == 1); - - // Counters: - uint64_t empty_docs = 0; - uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; - - // Current map index. - uint64_t map_index = 0; - - // For each epoch: - for (int32_t epoch=0; epoch= max_num_samples) { - if (verbose && (!second)) { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl << std::flush; - } - break; + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration = 0; iteration < 2; ++iteration) { + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch = 0; epoch < num_epochs; ++epoch) { + if (map_index >= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl + << std::flush; + } + break; + } + // For each document: + for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN) { + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; } - // For each document: - for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { - - // Document sentences are in [sent_index_first, sent_index_last) - const auto sent_index_first = docs[doc]; - const auto sent_index_last = docs[doc + 1]; - - // At the begining of the document previous index is the - // start index. - auto prev_start_index = sent_index_first; - - // Remaining documents. - auto num_remain_sent = sent_index_last - sent_index_first; - - // Some bookkeeping - if ((epoch == 0) && (!second)) { - if (num_remain_sent == 0) { - ++empty_docs; - } - if (num_remain_sent == 1) { - ++one_sent_docs; - } - } - - // Detect documents with long sentences. - bool contains_long_sentence = false; - if (num_remain_sent > 1) { - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - if (sizes[sent_index] > LONG_SENTENCE_LEN){ - if ((epoch == 0) && (!second)) { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } - - // If we have more than two sentences. - if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { - - // Set values. - auto seq_len = int32_t{0}; - auto num_sent = int32_t{0}; - auto target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - - // Loop through sentences. - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and if not only one sentence is left in the document. - // and if we have at least two sentneces. - // and if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent > 1) && - (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { - - // Check for overflow. - if ((3 * map_index + 2) > - std::numeric_limits::max()) { - cout << "number of samples exceeded maximum " - << "allowed by type int64: " - << std::numeric_limits::max() - << endl; - throw std::overflow_error("Number of samples"); - } - - // Populate the map. - if (second) { - const auto map_index_0 = 3 * map_index; - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(target_seq_len); - } - - // Update indices / counters. - ++map_index; - prev_start_index = sent_index + 1; - target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - seq_len = 0; - num_sent = 0; - } - - } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { - - if (!second) { - if (verbose) { - cout << " number of empty documents: " << empty_docs << - endl << std::flush; - cout << " number of documents with one sentence: " << - one_sent_docs << endl << std::flush; - cout << " number of documents with long sentences: " << - long_sent_docs << endl << std::flush; - cout << " will create mapping for " << map_index << - " samples" << endl << std::flush; - } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[3*map_index]; - num_samples = static_cast(map_index); + } } - } // for (int iteration=0; iteration < 2; ++iteration) { - - // Shuffle. - // We need a 64 bit random number generator as we might have more - // than 2 billion samples. - std::mt19937_64 rand64_gen(seed + 1); - for (auto i=(num_samples - 1); i > 0; --i) { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 3 * i; - const auto j0 = 3 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); - } + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len( + short_seq_ratio, max_seq_length, rand32_gen); + + // Loop through sentences. + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && (num_remain_sent > 1) && + (num_sent >= min_num_sent)) || + (num_remain_sent == 0)) { + // Check for overflow. + if ((3 * map_index + 2) > std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len( + short_seq_ratio, max_seq_length, rand32_gen); + seq_len = 0; + num_sent = 0; + } - // Method to deallocate memory. - py::capsule free_when_done(maps, [](void *mem_) { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; - }); + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << endl + << std::flush; + cout << " number of documents with one sentence: " << one_sent_docs + << endl + << std::flush; + cout << " number of documents with long sentences: " << long_sent_docs + << endl + << std::flush; + cout << " will create mapping for " << map_index << " samples" << endl + << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3 * map_index]; + num_samples = static_cast(map_index); + } - // Return the numpy array. - const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 3}, // shape - {3*byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i = (num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void* mem_) { + DocIdx* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3 * byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references } - py::array build_mapping(const py::array_t& docs_, - const py::array_t& sizes_, - const int num_epochs, + const py::array_t& sizes_, const int num_epochs, const uint64_t max_num_samples, - const int max_seq_length, - const double short_seq_prob, - const int seed, - const bool verbose, - const int32_t min_num_sent) { - - if (sizes_.size() > std::numeric_limits::max()) { - if (verbose) { - cout << " using uint64 for data mapping..." << endl << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose, - min_num_sent); - } else { - if (verbose) { - cout << " using uint32 for data mapping..." << endl << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose, - min_num_sent); + const int max_seq_length, const double short_seq_prob, + const int seed, const bool verbose, + const int32_t min_num_sent) { + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl( + docs_, sizes_, num_epochs, max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; } + return build_mapping_impl( + docs_, sizes_, num_epochs, max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, min_num_sent); + } } -template -py::array build_blocks_mapping_impl(const py::array_t& docs_, - const py::array_t& sizes_, - const py::array_t& titles_sizes_, - const int32_t num_epochs, - const uint64_t max_num_samples, - const int32_t max_seq_length, - const int32_t seed, - const bool verbose, - const bool use_one_sent_blocks) { - /* Build a mapping of (start-index, end-index, sequence-length) where - start and end index are the indices of the sentences in the sample - and sequence-length is the target sequence length. - */ - - // Consistency checks. - assert(num_epochs > 0); - assert(max_seq_length > 1); - assert(seed > 0); - - // Remove bound checks. - auto docs = docs_.unchecked<1>(); - auto sizes = sizes_.unchecked<1>(); - auto titles_sizes = titles_sizes_.unchecked<1>(); +template +py::array build_blocks_mapping_impl( + const py::array_t& docs_, const py::array_t& sizes_, + const py::array_t& titles_sizes_, const int32_t num_epochs, + const uint64_t max_num_samples, const int32_t max_seq_length, + const int32_t seed, const bool verbose, const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); - if (verbose) { - const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << - endl << std::flush; - cout << " sentences range: [" << sent_start_index << - ", " << sent_end_index << ")" << endl << std::flush; - cout << " total number of sentences: " << num_sentences << - endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " maximum number of samples: " << max_num_samples << - endl << std::flush; - cout << " maximum sequence length: " << max_seq_length << - endl << std::flush; - cout << " seed: " << seed << endl << - std::flush; - } + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 + << endl + << std::flush; + cout << " sentences range: [" << sent_start_index << ", " + << sent_end_index << ")" << endl + << std::flush; + cout << " total number of sentences: " << num_sentences << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " maximum number of samples: " << max_num_samples << endl + << std::flush; + cout << " maximum sequence length: " << max_seq_length << endl + << std::flush; + cout << " seed: " << seed << endl + << std::flush; + } - // Mapping and its length (1D). - int64_t num_samples = -1; - DocIdx* maps = NULL; + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; - // Acceptable number of sentences per block. - int min_num_sent = 2; - if (use_one_sent_blocks) { - min_num_sent = 1; - } + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } - // Perform two iterations, in the first iteration get the size - // and allocate memory and in the second iteration populate the map. - bool second = false; - for (int32_t iteration=0; iteration<2; ++iteration) { - - // Set the flag on second iteration. - second = (iteration == 1); - - // Current map index. - uint64_t map_index = 0; - - uint64_t empty_docs = 0; - uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; - // For each epoch: - for (int32_t epoch=0; epoch= max_num_samples) { - if (verbose && (!second)) { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl << std::flush; - } - break; - } - // For each document: - for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { - - // Document sentences are in [sent_index_first, sent_index_last) - const auto sent_index_first = docs[doc]; - const auto sent_index_last = docs[doc + 1]; - const auto target_seq_len = max_seq_length - titles_sizes[doc]; - - // At the begining of the document previous index is the - // start index. - auto prev_start_index = sent_index_first; - - // Remaining documents. - auto num_remain_sent = sent_index_last - sent_index_first; - - // Some bookkeeping - if ((epoch == 0) && (!second)) { - if (num_remain_sent == 0) { - ++empty_docs; - } - if (num_remain_sent == 1) { - ++one_sent_docs; - } - } - // Detect documents with long sentences. - bool contains_long_sentence = false; - if (num_remain_sent >= min_num_sent) { - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - if (sizes[sent_index] > LONG_SENTENCE_LEN){ - if ((epoch == 0) && (!second)) { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } - // If we have enough sentences and no long sentences. - if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { - - // Set values. - auto seq_len = int32_t{0}; - auto num_sent = int32_t{0}; - - // Loop through sentences. - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and there are an acceptable number of sentences left - // and if we have at least the minimum number of sentences. - // or if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent >= min_num_sent) && - (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { - - // Populate the map. - if (second) { - const auto map_index_0 = 4 * map_index; - // Each sample has 4 items: the starting sentence index, ending sentence index, - // the index of the document from which the block comes (used for fetching titles) - // and the unique id of the block (used for creating block indexes) - - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(doc); - maps[map_index_0 + 3] = static_cast(block_id); - } - - // Update indices / counters. - ++map_index; - ++block_id; - prev_start_index = sent_index + 1; - seq_len = 0; - num_sent = 0; - } - } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { - - if (!second) { - if (verbose) { - cout << " number of empty documents: " << empty_docs << - endl << std::flush; - cout << " number of documents with one sentence: " << - one_sent_docs << endl << std::flush; - cout << " number of documents with long sentences: " << - long_sent_docs << endl << std::flush; - cout << " will create mapping for " << map_index << - " samples" << endl << std::flush; + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration = 0; iteration < 2; ++iteration) { + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch = 0; epoch < num_epochs; ++epoch) { + // assign every block a unique id + int32_t block_id = 0; + + if (map_index >= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl + << std::flush; + } + break; + } + // For each document: + for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN) { + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[4*map_index]; - num_samples = static_cast(map_index); + } } - - } // for (int iteration=0; iteration < 2; ++iteration) { - - // Shuffle. - // We need a 64 bit random number generator as we might have more - // than 2 billion samples. - std::mt19937_64 rand64_gen(seed + 1); - for (auto i=(num_samples - 1); i > 0; --i) { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 4 * i; - const auto j0 = 4 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); - swap(maps[i0 + 3], maps[j0 + 3]); + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent)) || + (num_remain_sent == 0)) { + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending + // sentence index, the index of the document from which the + // block comes (used for fetching titles) and the unique id of + // the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << endl + << std::flush; + cout << " number of documents with one sentence: " << one_sent_docs + << endl + << std::flush; + cout << " number of documents with long sentences: " << long_sent_docs + << endl + << std::flush; + cout << " will create mapping for " << map_index << " samples" << endl + << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4 * map_index]; + num_samples = static_cast(map_index); } - // Method to deallocate memory. - py::capsule free_when_done(maps, [](void *mem_) { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; - }); - - // Return the numpy array. - const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 4}, // shape - {4*byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i = (num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void* mem_) { + DocIdx* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4 * byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references } -py::array build_blocks_mapping(const py::array_t& docs_, - const py::array_t& sizes_, - const py::array_t& titles_sizes_, - const int num_epochs, - const uint64_t max_num_samples, - const int max_seq_length, - const int seed, - const bool verbose, - const bool use_one_sent_blocks) { - - if (sizes_.size() > std::numeric_limits::max()) { - if (verbose) { - cout << " using uint64 for data mapping..." << endl << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); - } else { - if (verbose) { - cout << " using uint32 for data mapping..." << endl << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); +py::array build_blocks_mapping( + const py::array_t& docs_, const py::array_t& sizes_, + const py::array_t& titles_sizes_, const int num_epochs, + const uint64_t max_num_samples, const int max_seq_length, const int seed, + const bool verbose, const bool use_one_sent_blocks) { + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl( + docs_, sizes_, titles_sizes_, num_epochs, max_num_samples, + max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; } + return build_blocks_mapping_impl( + docs_, sizes_, titles_sizes_, num_epochs, max_num_samples, + max_seq_length, seed, verbose, use_one_sent_blocks); + } } PYBIND11_MODULE(helpers, m) { - m.def("build_mapping", &build_mapping); - m.def("build_blocks_mapping", &build_blocks_mapping); - m.def("build_sample_idx", &build_sample_idx); - m.def("build_blending_indices", &build_blending_indices); + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); } diff --git a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py index 6dac35ff9d41..220099f9ba32 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py @@ -2,12 +2,11 @@ import random import numpy as np -from torch.utils.data import Dataset - -from megatron import get_tokenizer -from megatron import get_args +from megatron import get_args, get_tokenizer from megatron.data.dataset_utils import get_indexed_dataset_ from megatron.data.realm_dataset_utils import get_block_samples_mapping +from torch.utils.data import Dataset + def make_attention_mask(source_block, target_block): """ @@ -20,16 +19,17 @@ def make_attention_mask(source_block, target_block): # (source_length, target_length) return mask + def get_ict_dataset(use_titles=True, query_in_block_prob=1): """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) rather than for training, since it is only built with a single epoch sample mapping. """ args = get_args() - block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) - titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) + block_dataset = get_indexed_dataset_(args.data_path, "mmap", True) + titles_dataset = get_indexed_dataset_(args.titles_data_path, "mmap", True) kwargs = dict( - name='full', + name="full", block_dataset=block_dataset, title_dataset=titles_dataset, data_prefix=args.data_path, @@ -39,7 +39,7 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1): seed=1, query_in_block_prob=query_in_block_prob, use_titles=use_titles, - use_one_sent_docs=args.use_one_sent_docs + use_one_sent_docs=args.use_one_sent_docs, ) dataset = ICTDataset(**kwargs) return dataset @@ -47,9 +47,22 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1): class ICTDataset(Dataset): """Dataset containing sentences and their blocks for an inverse cloze task.""" - def __init__(self, name, block_dataset, title_dataset, data_prefix, - num_epochs, max_num_samples, max_seq_length, query_in_block_prob, - seed, use_titles=True, use_one_sent_docs=False, binary_head=False): + + def __init__( + self, + name, + block_dataset, + title_dataset, + data_prefix, + num_epochs, + max_num_samples, + max_seq_length, + query_in_block_prob, + seed, + use_titles=True, + use_one_sent_docs=False, + binary_head=False, + ): self.name = name self.seed = seed self.max_seq_length = max_seq_length @@ -61,8 +74,16 @@ def __init__(self, name, block_dataset, title_dataset, data_prefix, self.use_one_sent_docs = use_one_sent_docs self.samples_mapping = get_block_samples_mapping( - block_dataset, title_dataset, data_prefix, num_epochs, - max_num_samples, max_seq_length, seed, name, use_one_sent_docs) + block_dataset, + title_dataset, + data_prefix, + num_epochs, + max_num_samples, + max_seq_length, + seed, + name, + use_one_sent_docs, + ) self.tokenizer = get_tokenizer() self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) self.vocab_id_to_token_list = self.tokenizer.inv_vocab @@ -99,8 +120,8 @@ def __getitem__(self, idx): # still need to truncate because blocks are concluded when # the sentence lengths have exceeded max_seq_length. - query = query[:self.max_seq_length - 2] - block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] + query = query[: self.max_seq_length - 2] + block = list(itertools.chain(*block))[: self.max_seq_length - title_pad_offset] query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title) @@ -111,13 +132,13 @@ def __getitem__(self, idx): block_data = sample_data.as_array() sample = { - 'query_tokens': query_tokens, - 'query_mask': query_mask, - 'query_pad_mask': query_pad_mask, - 'context_tokens': context_tokens, - 'context_mask': context_mask, - 'context_pad_mask': context_pad_mask, - 'block_data': block_data, + "query_tokens": query_tokens, + "query_mask": query_mask, + "query_pad_mask": query_pad_mask, + "context_tokens": context_tokens, + "context_mask": context_mask, + "context_pad_mask": context_pad_mask, + "block_data": block_data, } return sample @@ -127,7 +148,7 @@ def get_block(self, start_idx, end_idx, doc_idx): block = [self.block_dataset[i] for i in range(start_idx, end_idx)] title = self.title_dataset[int(doc_idx)] - block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))] + block = list(itertools.chain(*block))[: self.max_seq_length - (3 + len(title))] block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) return block_tokens, block_pad_mask diff --git a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py index 9a25dc453c24..961a1650bd74 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py @@ -27,17 +27,17 @@ def __best_fitting_dtype(vocab_size=None): def get_available_dataset_impl(): - return ['lazy', 'cached', 'mmap'] + return ["lazy", "cached", "mmap"] def infer_dataset_impl(path): if IndexedDataset.exists(path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) if magic == IndexedDataset._HDR_MAGIC: - return 'cached' + return "cached" elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: - return 'mmap' + return "mmap" else: return None else: @@ -47,7 +47,7 @@ def infer_dataset_impl(path): def make_builder(out_file, impl, vocab_size=None): - if impl == 'mmap': + if impl == "mmap": return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size)) else: return IndexedDatasetBuilder(out_file) @@ -58,20 +58,20 @@ def make_dataset(path, impl, skip_warmup=False): print(f"Dataset does not exist: {path}") print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") return None - if impl == 'infer': + if impl == "infer": impl = infer_dataset_impl(path) - if impl == 'lazy' and IndexedDataset.exists(path): + if impl == "lazy" and IndexedDataset.exists(path): return IndexedDataset(path) - elif impl == 'cached' and IndexedDataset.exists(path): + elif impl == "cached" and IndexedDataset.exists(path): return IndexedCachedDataset(path) - elif impl == 'mmap' and MMapIndexedDataset.exists(path): + elif impl == "mmap" and MMapIndexedDataset.exists(path): return MMapIndexedDataset(path, skip_warmup) print(f"Unknown dataset implementation: {impl}") return None def dataset_exists(path, impl): - if impl == 'mmap': + if impl == "mmap": return MMapIndexedDataset.exists(path) else: return IndexedDataset.exists(path) @@ -98,11 +98,11 @@ def code(dtype): def index_file_path(prefix_path): - return prefix_path + '.idx' + return prefix_path + ".idx" def data_file_path(prefix_path): - return prefix_path + '.bin' + return prefix_path + ".bin" def create_doc_idx(sizes): @@ -115,7 +115,8 @@ def create_doc_idx(sizes): class IndexedDataset(torch.utils.data.Dataset): """Loader for IndexedDataset""" - _HDR_MAGIC = b'TNTIDX\x00\x00' + + _HDR_MAGIC = b"TNTIDX\x00\x00" def __init__(self, path): super().__init__() @@ -124,27 +125,28 @@ def __init__(self, path): self.read_index(path) def read_index(self, path): - with open(index_file_path(path), 'rb') as f: + with open(index_file_path(path), "rb") as f: magic = f.read(8) - assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. ' - 'Make sure that --dataset-impl is configured properly.') + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly." + ) version = f.read(8) - assert struct.unpack('= self._len: - raise IndexError('index out of range') + raise IndexError("index out of range") def __del__(self): if self.data_file: @@ -157,7 +159,7 @@ def __getitem__(self, idx): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) @@ -166,7 +168,7 @@ def __getitem__(self, idx): start, stop, step = idx.indices(len(self)) if step != 1: raise ValueError("Slices into indexed_dataset must be contiguous") - sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]] + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] size = sum(sizes) a = np.empty(size, dtype=self.dtype) self.data_file.seek(self.data_offsets[start] * self.element_size) @@ -186,15 +188,14 @@ def size(self, index): @staticmethod def exists(path): - return (os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))) + return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) @property def supports_prefetch(self): - return False # avoid prefetching to save memory + return False # avoid prefetching to save memory class IndexedCachedDataset(IndexedDataset): - def __init__(self, path): super().__init__(path) self.cache = None @@ -219,7 +220,7 @@ def prefetch(self, indices): for i in indices: self.cache_index[i] = ptx size = self.data_offsets[i + 1] - self.data_offsets[i] - a = self.cache[ptx:ptx + size] + a = self.cache[ptx : ptx + size] self.data_file.seek(self.data_offsets[i] * self.element_size) self.data_file.readinto(a) ptx += size @@ -233,10 +234,10 @@ def __getitem__(self, idx): if isinstance(idx, int): i = idx self.check_index(i) - tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]] + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] a = np.empty(tensor_size, dtype=self.dtype) ptx = self.cache_index[i] - np.copyto(a, self.cache[ptx:ptx + a.size]) + np.copyto(a, self.cache[ptx : ptx + a.size]) return a elif isinstance(idx, slice): # Hack just to make this work, can optimizer later if necessary @@ -250,7 +251,7 @@ class IndexedDatasetBuilder(object): element_sizes = {np.uint8: 1, np.int8: 1, np.int16: 2, np.int32: 4, np.int64: 8, float: 4, np.double: 8} def __init__(self, out_file, dtype=np.int32): - self.out_file = open(out_file, 'wb') + self.out_file = open(out_file, "wb") self.dtype = dtype self.data_offsets = [0] self.dim_offsets = [0] @@ -280,7 +281,7 @@ def merge_file_(self, another_file): for dim_offset in index.dim_offsets[1:]: self.dim_offsets.append(begin + dim_offset) - with open(data_file_path(another_file), 'rb') as f: + with open(data_file_path(another_file), "rb") as f: while True: data = f.read(1024) if data: @@ -290,12 +291,12 @@ def merge_file_(self, another_file): def finalize(self, index_file): self.out_file.close() - index = open(index_file, 'wb') - index.write(b'TNTIDX\x00\x00') - index.write(struct.pack('= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # return True return False @@ -320,7 +320,7 @@ def _clean_text(self, text): output = [] for char in text: cp = ord(char) - if cp == 0 or cp == 0xfffd or _is_control(char): + if cp == 0 or cp == 0xFFFD or _is_control(char): continue if _is_whitespace(char): output.append(" ") @@ -422,8 +422,7 @@ def _is_punctuation(char): # Characters such as "^", "$", and "`" are not in the Unicode # Punctuation class but we treat them as punctuation anyways, for # consistency. - if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or - (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): return True cat = unicodedata.category(char) if cat.startswith("P"): diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py index ba832b5cdce9..6c7bfd69f3f9 100644 --- a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py +++ b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py @@ -25,16 +25,15 @@ def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): """Initialize tokenizer.""" if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: - print('> building {} tokenizer ...'.format(tokenizer_type), flush=True) + print("> building {} tokenizer ...".format(tokenizer_type), flush=True) # Select and instantiate the tokenizer. - if tokenizer_type == 'BertWordPieceLowerCase': + if tokenizer_type == "BertWordPieceLowerCase": tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True, vocab_extra_ids=vocab_extra_ids) - elif tokenizer_type == 'BertWordPieceCase': + elif tokenizer_type == "BertWordPieceCase": tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False, vocab_extra_ids=vocab_extra_ids) else: - raise NotImplementedError('{} tokenizer is not ' - 'implemented.'.format(tokenizer_type)) + raise NotImplementedError("{} tokenizer is not " "implemented.".format(tokenizer_type)) # Add vocab size. padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size) @@ -55,9 +54,11 @@ def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): while (after % multiple) != 0: after += 1 if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: - print(' > padded vocab (size: {}) with {} dummy tokens ' - '(new size: {})'.format(orig_vocab_size, after - orig_vocab_size, after), - flush=True) + print( + " > padded vocab (size: {}) with {} dummy tokens " + "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after), + flush=True, + ) return after @@ -77,46 +78,38 @@ def vocab_size(self): @abstractmethod def vocab(self): """Dictionary from vocab text token to id token.""" - pass @property @abstractmethod def inv_vocab(self): """Dictionary from vocab id token to text token.""" - pass @abstractmethod def tokenize(self, text): pass def detokenize(self, token_ids): - raise NotImplementedError('detokenizer is not implemented for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name)) @property def cls(self): - raise NotImplementedError('CLS is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name)) @property def sep(self): - raise NotImplementedError('SEP is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name)) @property def pad(self): - raise NotImplementedError('PAD is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name)) @property def eod(self): - raise NotImplementedError('EOD is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name)) @property def mask(self): - raise NotImplementedError('MASK is not provided for {} ' - 'tokenizer'.format(self.name)) + raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name)) class _BertWordPieceTokenizer(AbstractTokenizer): @@ -124,24 +117,24 @@ class _BertWordPieceTokenizer(AbstractTokenizer): def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): if lower_case: - name = 'BERT Lower Case' + name = "BERT Lower Case" else: - name = 'BERT Upper Case' + name = "BERT Upper Case" super().__init__(name) self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=lower_case) - self.cls_id = self.tokenizer.vocab['[CLS]'] - self.sep_id = self.tokenizer.vocab['[SEP]'] - self.pad_id = self.tokenizer.vocab['[PAD]'] - self.mask_id = self.tokenizer.vocab['[MASK]'] + self.cls_id = self.tokenizer.vocab["[CLS]"] + self.sep_id = self.tokenizer.vocab["[SEP]"] + self.pad_id = self.tokenizer.vocab["[PAD]"] + self.mask_id = self.tokenizer.vocab["[MASK]"] self._additional_special_tokens = [] # (dsachan) Add BOS and EOS tokens - SPECIAL_TOKENS = {'eos_token': '[EOS]', 'bos_token': '[BOS]'} - self._bos_token = '[BOS]' + SPECIAL_TOKENS = {"eos_token": "[EOS]", "bos_token": "[BOS]"} + self._bos_token = "[BOS]" self.add_token(self._bos_token) self._bos_token_id = self.vocab.get(self._bos_token) - self._eos_token = '[EOS]' + self._eos_token = "[EOS]" self.add_token(self._eos_token) self._eos_token_id = self.vocab.get(self._eos_token) @@ -185,7 +178,7 @@ def decode(self, ids): def decode_token_ids(self, token_ids): tokens = self.tokenizer.convert_ids_to_tokens(token_ids) - exclude_list = ['[PAD]', '[CLS]'] + exclude_list = ["[PAD]", "[CLS]"] non_pads = [t for t in tokens if t not in exclude_list] result = "" @@ -215,32 +208,32 @@ def mask(self): @property def bos_token(self): - """ Beginning of sentence token id """ + """Beginning of sentence token id""" return self._bos_token @property def eos_token(self): - """ End of sentence token id """ + """End of sentence token id""" return self._eos_token @property def additional_special_tokens(self): - """ All the additional special tokens you may want to use (list of strings).""" + """All the additional special tokens you may want to use (list of strings).""" return self._additional_special_tokens @property def bos_token_id(self): - """ Id of the beginning of sentence token in the vocabulary.""" + """Id of the beginning of sentence token in the vocabulary.""" return self._bos_token_id @property def eos_token_id(self): - """ Id of the end of sentence token in the vocabulary.""" + """Id of the end of sentence token in the vocabulary.""" return self._eos_token_id @property def additional_special_tokens_ids(self): - """ Ids of all the additional special tokens in the vocabulary (list of integers).""" + """Ids of all the additional special tokens in the vocabulary (list of integers).""" return [self.vocab.get(token) for token in self._additional_special_tokens] @additional_special_tokens.setter diff --git a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py index b3f2487a438b..869ff720f4b0 100644 --- a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py +++ b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py @@ -1,17 +1,12 @@ import torch -import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.logging import get_dist_logger - -from .cross_entropy import vocab_cross_entropy class BertLoss(nn.Module): - def forward(self, lm_loss, sop_logits, loss_mask, sentence_order): lm_loss_ = lm_loss.float() loss_mask = loss_mask.float() diff --git a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py index ed15c6ea8054..b5d9ea919261 100644 --- a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py +++ b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py @@ -1,11 +1,8 @@ import torch from torch.cuda.amp import custom_bwd, custom_fwd -from colossalai.legacy.context.parallel_mode import ParallelMode - class _VocabCrossEntropy(torch.autograd.Function): - @staticmethod @custom_fwd def forward(ctx, vocab_parallel_logits, target): @@ -59,7 +56,7 @@ def backward(ctx, grad_output): # Add the gradient from matching classes. arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() # Finally elementwise multiplication with the output gradients. grad_input.mul_(grad_output.unsqueeze(dim=-1)) diff --git a/examples/tutorial/sequence_parallel/loss_func/utils.py b/examples/tutorial/sequence_parallel/loss_func/utils.py index a3d92f294326..35fa73896c46 100644 --- a/examples/tutorial/sequence_parallel/loss_func/utils.py +++ b/examples/tutorial/sequence_parallel/loss_func/utils.py @@ -1,11 +1,9 @@ - import torch def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, '{} is not divisible by {}'.format( - numerator, denominator) + assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) def divide(numerator, denominator): @@ -15,8 +13,7 @@ def divide(numerator, denominator): return numerator // denominator -def split_tensor_along_last_dim(tensor, num_partitions, - contiguous_split_chunks=False): +def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): """Split a tensor along its last dimension. Arguments: tensor: input tensor. @@ -38,12 +35,11 @@ def split_tensor_along_last_dim(tensor, num_partitions, class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the - first and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last)""" + first and last index of the vocabulary belonging to the `rank` + partition: Note that indices in [fist, last)""" @staticmethod - def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, - rank, world_size): + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size): index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size return index_f, index_l @@ -51,5 +47,4 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, @staticmethod def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): per_partition_vocab_size = divide(global_vocab_size, world_size) - return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size) diff --git a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py index 8d95679ff76d..866d0d54583b 100644 --- a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py +++ b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py @@ -21,16 +21,17 @@ class AnnealingLR(object): """Anneals the learning rate.""" - def __init__(self, - optimizer, - max_lr, - min_lr, - warmup_steps, - decay_steps, - decay_style, - use_checkpoint_lr_scheduler=True, - override_lr_scheduler=False): - + def __init__( + self, + optimizer, + max_lr, + min_lr, + warmup_steps, + decay_steps, + decay_style, + use_checkpoint_lr_scheduler=True, + override_lr_scheduler=False, + ): # Class values. self.optimizer = optimizer @@ -50,23 +51,21 @@ def __init__(self, self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler if self.override_lr_scheduler: - assert not self.use_checkpoint_lr_scheduler, 'both override and '\ - 'use-checkpoint are set.' + assert not self.use_checkpoint_lr_scheduler, "both override and " "use-checkpoint are set." # Set the learning rate self.step(0) def get_lr(self): """Learning rate decay functions from: - https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" + https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" # Use linear warmup for the initial part. if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: - return self.max_lr * float(self.num_steps) / \ - float(self.warmup_steps) + return self.max_lr * float(self.num_steps) / float(self.warmup_steps) # If the learning rate is constant, just return the initial value. - if self.decay_style == 'constant': + if self.decay_style == "constant": return self.max_lr # For any steps larger than `self.decay_steps`, use `self.min_lr`. @@ -81,13 +80,12 @@ def get_lr(self): assert decay_ratio <= 1.0 delta_lr = self.max_lr - self.min_lr - if self.decay_style == 'linear': - coeff = (1.0 - decay_ratio) - elif self.decay_style == 'cosine': + if self.decay_style == "linear": + coeff = 1.0 - decay_ratio + elif self.decay_style == "cosine": coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) else: - raise Exception('{} decay style is not supported.'.format( - self.decay_style)) + raise Exception("{} decay style is not supported.".format(self.decay_style)) return self.min_lr + coeff * delta_lr @@ -96,16 +94,16 @@ def step(self, increment=1): self.num_steps += increment new_lr = self.get_lr() for group in self.optimizer.param_groups: - group['lr'] = new_lr + group["lr"] = new_lr def state_dict(self): state_dict = { - 'max_lr': self.max_lr, - 'warmup_steps': self.warmup_steps, - 'num_steps': self.num_steps, - 'decay_style': self.decay_style, - 'decay_steps': self.decay_steps, - 'min_lr': self.min_lr + "max_lr": self.max_lr, + "warmup_steps": self.warmup_steps, + "num_steps": self.num_steps, + "decay_style": self.decay_style, + "decay_steps": self.decay_steps, + "min_lr": self.min_lr, } return state_dict @@ -116,43 +114,35 @@ def _check_and_set(self, cls_value, sd_value, name): return cls_value if not self.use_checkpoint_lr_scheduler: - assert cls_value == sd_value, \ - f'AnnealingLR: class input value {cls_value} and checkpoint' \ - f'value {sd_value} for {name} do not match' + assert cls_value == sd_value, ( + f"AnnealingLR: class input value {cls_value} and checkpoint" f"value {sd_value} for {name} do not match" + ) return sd_value def load_state_dict(self, sd): - - if 'start_lr' in sd: - max_lr_ = sd['start_lr'] + if "start_lr" in sd: + max_lr_ = sd["start_lr"] else: - max_lr_ = sd['max_lr'] - self.max_lr = self._check_and_set(self.max_lr, max_lr_, - 'learning rate') + max_lr_ = sd["max_lr"] + self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate") - self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], - 'minimum learning rate') + self.min_lr = self._check_and_set(self.min_lr, sd["min_lr"], "minimum learning rate") - if 'warmup_iter' in sd: - warmup_steps_ = sd['warmup_iter'] + if "warmup_iter" in sd: + warmup_steps_ = sd["warmup_iter"] else: - warmup_steps_ = sd['warmup_steps'] - self.warmup_steps = self._check_and_set(self.warmup_steps, - warmup_steps_, - 'warmup iterations') + warmup_steps_ = sd["warmup_steps"] + self.warmup_steps = self._check_and_set(self.warmup_steps, warmup_steps_, "warmup iterations") - if 'end_iter' in sd: - decay_steps_ = sd['end_iter'] + if "end_iter" in sd: + decay_steps_ = sd["end_iter"] else: - decay_steps_ = sd['decay_steps'] - self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, - 'total number of iterations') - self.decay_style = self._check_and_set(self.decay_style, - sd['decay_style'], - 'decay style') - - if 'num_iters' in sd: - num_steps = sd['num_iters'] + decay_steps_ = sd["decay_steps"] + self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, "total number of iterations") + self.decay_style = self._check_and_set(self.decay_style, sd["decay_style"], "decay style") + + if "num_iters" in sd: + num_steps = sd["num_iters"] else: - num_steps = sd['num_steps'] + num_steps = sd["num_steps"] self.step(increment=num_steps) diff --git a/examples/tutorial/sequence_parallel/model/__init__.py b/examples/tutorial/sequence_parallel/model/__init__.py index 139597f9cb07..e69de29bb2d1 100644 --- a/examples/tutorial/sequence_parallel/model/__init__.py +++ b/examples/tutorial/sequence_parallel/model/__init__.py @@ -1,2 +0,0 @@ - - diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py index 4ba64bbe2b9f..7b0e93d958ca 100644 --- a/examples/tutorial/sequence_parallel/model/bert.py +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -16,7 +16,6 @@ class BertForPretrain(nn.Module): - def __init__( self, vocab_size, @@ -34,7 +33,9 @@ def __init__( ): super().__init__() self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) - assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + assert ( + max_sequence_length % self.seq_parallel_size == 0 + ), "sequence length is not divisible by the sequence parallel size" self.sub_seq_length = max_sequence_length // self.seq_parallel_size self.init_std = init_std self.num_layers = num_layers @@ -43,28 +44,32 @@ def __init__( num_tokentypes = 0 self.preprocessor = PreProcessor(self.sub_seq_length) - self.embedding = Embedding(hidden_size=hidden_size, - vocab_size=vocab_size, - max_sequence_length=max_sequence_length, - embedding_dropout_prob=dropout_prob, - num_tokentypes=num_tokentypes) + self.embedding = Embedding( + hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes, + ) self.bert_layers = nn.ModuleList() for i in range(num_layers): - bert_layer = BertLayer(layer_number=i + 1, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - attention_dropout=dropout_prob, - mlp_ratio=mlp_ratio, - hidden_dropout=dropout_prob, - convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - is_naive_fp16=is_naive_fp16) + bert_layer = BertLayer( + layer_number=i + 1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16, + ) self.bert_layers.append(bert_layer) self.layer_norm = LayerNorm(hidden_size) - self.head = BertDualHead(hidden_size, - self.embedding.word_embedding_weight.size(0), - add_binary_head=add_binary_head) + self.head = BertDualHead( + hidden_size, self.embedding.word_embedding_weight.size(0), add_binary_head=add_binary_head + ) self.reset_parameters() def _init_normal(self, tensor): @@ -122,27 +127,30 @@ def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels): class PipelineBertForPretrain(nn.Module): - - def __init__(self, - vocab_size, - hidden_size, - max_sequence_length, - num_attention_heads, - num_layers, - add_binary_head, - is_naive_fp16, - num_tokentypes=2, - dropout_prob=0.1, - mlp_ratio=4, - init_std=0.02, - convert_fp16_to_fp32_in_softmax=False, - first_stage=True, - last_stage=True, - start_idx=None, - end_idx=None): + def __init__( + self, + vocab_size, + hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + first_stage=True, + last_stage=True, + start_idx=None, + end_idx=None, + ): super().__init__() self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) - assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' + assert ( + max_sequence_length % self.seq_parallel_size == 0 + ), "sequence length is not divisible by the sequence parallel size" self.sub_seq_length = max_sequence_length // self.seq_parallel_size self.init_std = init_std self.num_layers = num_layers @@ -156,11 +164,13 @@ def __init__(self, self.preprocessor = PreProcessor(self.sub_seq_length) if self.first_stage: - self.embedding = Embedding(hidden_size=hidden_size, - vocab_size=vocab_size, - max_sequence_length=max_sequence_length, - embedding_dropout_prob=dropout_prob, - num_tokentypes=num_tokentypes) + self.embedding = Embedding( + hidden_size=hidden_size, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + embedding_dropout_prob=dropout_prob, + num_tokentypes=num_tokentypes, + ) # transformer layers self.bert_layers = nn.ModuleList() @@ -170,14 +180,16 @@ def __init__(self, end_idx = num_layers for i in range(start_idx, end_idx): - bert_layer = BertLayer(layer_number=i + 1, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - attention_dropout=dropout_prob, - mlp_ratio=mlp_ratio, - hidden_dropout=dropout_prob, - convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - is_naive_fp16=is_naive_fp16) + bert_layer = BertLayer( + layer_number=i + 1, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + attention_dropout=dropout_prob, + mlp_ratio=mlp_ratio, + hidden_dropout=dropout_prob, + convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, + is_naive_fp16=is_naive_fp16, + ) self.bert_layers.append(bert_layer) if self.last_stage: @@ -256,7 +268,7 @@ def _filter_kwargs(func, kwargs): return {k: v for k, v in kwargs.items() if k in sig.parameters} -def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs): +def build_pipeline_bert(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): logger = get_dist_logger() pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) @@ -265,12 +277,12 @@ def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **k parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank] models = [] for start, end in parts: - kwargs['num_layers'] = num_layers - kwargs['start_idx'] = start - kwargs['end_idx'] = end - kwargs['first_stage'] = start == 0 - kwargs['last_stage'] = end == num_layers - logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers') + kwargs["num_layers"] = num_layers + kwargs["start_idx"] = start + kwargs["end_idx"] = end + kwargs["first_stage"] = start == 0 + kwargs["last_stage"] = end == num_layers + logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers") chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device) if start == 0: wrapper.register_module(chunk.embedding.word_embeddings) diff --git a/examples/tutorial/sequence_parallel/model/layers/__init__.py b/examples/tutorial/sequence_parallel/model/layers/__init__.py index 3a8823caa81b..58495c516239 100644 --- a/examples/tutorial/sequence_parallel/model/layers/__init__.py +++ b/examples/tutorial/sequence_parallel/model/layers/__init__.py @@ -1,4 +1,4 @@ -from .embedding import VocabEmbedding, Embedding from .bert_layer import BertLayer +from .embedding import Embedding, VocabEmbedding from .head import BertDualHead from .preprocess import PreProcessor diff --git a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py index 56ba511d8274..1ef16ee6ad79 100644 --- a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py +++ b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py @@ -20,18 +20,20 @@ class BertLayer(nn.Module): output of the same size. """ - def __init__(self, - layer_number, - hidden_size, - num_attention_heads, - attention_dropout, - mlp_ratio, - hidden_dropout, - is_naive_fp16, - apply_residual_connection_post_layernorm=False, - fp32_residual_connection=False, - bias_dropout_fusion: bool = True, - convert_fp16_to_fp32_in_softmax: bool = False): + def __init__( + self, + layer_number, + hidden_size, + num_attention_heads, + attention_dropout, + mlp_ratio, + hidden_dropout, + is_naive_fp16, + apply_residual_connection_post_layernorm=False, + fp32_residual_connection=False, + bias_dropout_fusion: bool = True, + convert_fp16_to_fp32_in_softmax: bool = False, + ): super().__init__() self.layer_number = layer_number @@ -50,7 +52,8 @@ def __init__(self, layer_number=layer_number, apply_query_key_layer_scaling=True, convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - fp16=is_naive_fp16) + fp16=is_naive_fp16, + ) self.hidden_dropout = hidden_dropout self.bias_dropout_fusion = bias_dropout_fusion @@ -90,8 +93,9 @@ def forward(self, hidden_states, attention_mask): # re-enable torch grad to enable fused optimization. with torch.enable_grad(): - layernorm_input = bias_dropout_add_func(attention_output, attention_bias.expand_as(residual), residual, - self.hidden_dropout) + layernorm_input = bias_dropout_add_func( + attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout + ) # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) diff --git a/examples/tutorial/sequence_parallel/model/layers/dropout.py b/examples/tutorial/sequence_parallel/model/layers/dropout.py index 0e99105b8f7e..18eae0d63cd1 100644 --- a/examples/tutorial/sequence_parallel/model/layers/dropout.py +++ b/examples/tutorial/sequence_parallel/model/layers/dropout.py @@ -1,5 +1,6 @@ import torch + def bias_dropout_add(x, bias, residual, prob, training): # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor out = torch.nn.functional.dropout(x + bias, p=prob, training=training) @@ -10,4 +11,5 @@ def bias_dropout_add(x, bias, residual, prob, training): def get_bias_dropout_add(training): def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) - return _bias_dropout_add \ No newline at end of file + + return _bias_dropout_add diff --git a/examples/tutorial/sequence_parallel/model/layers/embedding.py b/examples/tutorial/sequence_parallel/model/layers/embedding.py index 0700d960d845..03183c55243f 100644 --- a/examples/tutorial/sequence_parallel/model/layers/embedding.py +++ b/examples/tutorial/sequence_parallel/model/layers/embedding.py @@ -5,7 +5,6 @@ class VocabEmbedding(torch.nn.Module): - def __init__(self, num_embeddings, embedding_dim): super(VocabEmbedding, self).__init__() # Keep the input dimensions. @@ -13,26 +12,29 @@ def __init__(self, num_embeddings, embedding_dim): self.embedding_dim = embedding_dim self.padding_idx = None self.max_norm = None - self.norm_type = 2. + self.norm_type = 2.0 self.scale_grad_by_freq = False self.sparse = False self._weight = None # Allocate weights and initialize. - self.weight = nn.Parameter(torch.empty( - self.num_embeddings, self.embedding_dim)) + self.weight = nn.Parameter(torch.empty(self.num_embeddings, self.embedding_dim)) init.xavier_uniform_(self.weight) def forward(self, hidden_state): - output = F.embedding(hidden_state, self.weight, - self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, - self.sparse) + output = F.embedding( + hidden_state, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) return output def __repr__(self): - return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \ - f'embedding_dim={self.embedding_dim})' + return f"VocabEmbedding(num_embeddings={self.num_embeddings}, " f"embedding_dim={self.embedding_dim})" class Embedding(nn.Module): @@ -48,12 +50,7 @@ class Embedding(nn.Module): will ignore this embedding """ - def __init__(self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - num_tokentypes): + def __init__(self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes): super(Embedding, self).__init__() self.hidden_size = hidden_size @@ -62,16 +59,14 @@ def __init__(self, self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size) # Position embedding (serial). - self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size) + self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) # Token type embedding. # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, - self.hidden_size) + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) else: self.tokentype_embeddings = None diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py index 9e25157e1b40..75afeee60ad4 100644 --- a/examples/tutorial/sequence_parallel/model/layers/head.py +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -3,12 +3,10 @@ import torch.nn.functional as F from loss_func.cross_entropy import vocab_cross_entropy -import colossalai from colossalai.kernel import LayerNorm from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc -from .embedding import VocabEmbedding from .linear import Linear from .pooler import Pooler @@ -26,7 +24,6 @@ def __init__( vocab_size, hidden_size, ): - super(BertLMHead, self).__init__() self.bias = torch.nn.Parameter(torch.zeros(vocab_size)) @@ -46,7 +43,6 @@ def forward(self, hidden_states, word_embeddings_weight, lm_labels): class BertBinaryHead(nn.Module): - def __init__(self, hidden_size): super().__init__() self.pooler = Pooler(hidden_size) @@ -62,7 +58,6 @@ def forward(self, hidden_states): class BertDualHead(nn.Module): - def __init__(self, hidden_size, vocab_size, add_binary_head): super().__init__() self.lm_head = BertLMHead(vocab_size, hidden_size) diff --git a/examples/tutorial/sequence_parallel/model/layers/init_method.py b/examples/tutorial/sequence_parallel/model/layers/init_method.py index 1b409dfe4054..22d12a504fab 100644 --- a/examples/tutorial/sequence_parallel/model/layers/init_method.py +++ b/examples/tutorial/sequence_parallel/model/layers/init_method.py @@ -1,6 +1,8 @@ -import torch import math +import torch + + def init_normal(tensor, sigma): """Init method based on N(0, sigma).""" torch.nn.init.normal_(tensor, mean=0.0, std=sigma) diff --git a/examples/tutorial/sequence_parallel/model/layers/linear.py b/examples/tutorial/sequence_parallel/model/layers/linear.py index 5ae7d671e2bf..5592f6e8c209 100644 --- a/examples/tutorial/sequence_parallel/model/layers/linear.py +++ b/examples/tutorial/sequence_parallel/model/layers/linear.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn -from torch.nn import Parameter import torch.nn.functional as F import torch.nn.init as init +from torch.nn import Parameter class Linear(nn.Module): @@ -24,11 +24,7 @@ class Linear(nn.Module): adding bias but instead return it. """ - def __init__(self, - input_size, - output_size, - bias=True, - skip_bias_add=False): + def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): super(Linear, self).__init__() # Keep input parameters @@ -36,9 +32,12 @@ def __init__(self, self.output_size = output_size self.skip_bias_add = skip_bias_add - self.weight = Parameter(torch.empty(self.output_size, - self.input_size, - )) + self.weight = Parameter( + torch.empty( + self.output_size, + self.input_size, + ) + ) init.normal_(self.weight) if bias: self.bias = Parameter(torch.empty(self.output_size)) @@ -46,7 +45,7 @@ def __init__(self, with torch.no_grad(): self.bias.zero_() else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def forward(self, input_): # Matrix multiply. @@ -59,5 +58,7 @@ def forward(self, input_): return output def __repr__(self): - return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ - f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' + return ( + f"Linear(in_features={self.input_size}, out_features={self.output_size}, " + + f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})" + ) diff --git a/examples/tutorial/sequence_parallel/model/layers/mlp.py b/examples/tutorial/sequence_parallel/model/layers/mlp.py index a255de813d13..54a695fda402 100644 --- a/examples/tutorial/sequence_parallel/model/layers/mlp.py +++ b/examples/tutorial/sequence_parallel/model/layers/mlp.py @@ -1,10 +1,10 @@ -import torch import torch.nn as nn import torch.nn.functional as F -from .linear import Linear from colossalai.kernel.jit import bias_gelu_impl +from .linear import Linear + class TransformerMLP(nn.Module): """MLP. @@ -18,19 +18,13 @@ def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True): super(TransformerMLP, self).__init__() # Project to 4h. - self.dense_h_to_4h = Linear( - hidden_size, - int(hidden_size*mlp_ratio), - skip_bias_add=True) + self.dense_h_to_4h = Linear(hidden_size, int(hidden_size * mlp_ratio), skip_bias_add=True) self.bias_gelu_fusion = fuse_gelu self.activation_func = F.gelu # Project back to h. - self.dense_4h_to_h = Linear( - int(hidden_size*mlp_ratio), - hidden_size, - skip_bias_add=True) + self.dense_4h_to_h = Linear(int(hidden_size * mlp_ratio), hidden_size, skip_bias_add=True) def forward(self, hidden_states): # hidden states should be in the shape of [s, b, h] @@ -39,11 +33,9 @@ def forward(self, hidden_states): intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) if self.bias_gelu_fusion: - intermediate_parallel = \ - bias_gelu_impl(intermediate_parallel, bias_parallel) + intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) else: - intermediate_parallel = \ - self.activation_func(intermediate_parallel + bias_parallel) + intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) # [s, b, h] output, output_bias = self.dense_4h_to_h(intermediate_parallel) diff --git a/examples/tutorial/sequence_parallel/model/layers/pooler.py b/examples/tutorial/sequence_parallel/model/layers/pooler.py index 282ed114790b..c3397787aecf 100644 --- a/examples/tutorial/sequence_parallel/model/layers/pooler.py +++ b/examples/tutorial/sequence_parallel/model/layers/pooler.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn + from .linear import Linear diff --git a/examples/tutorial/sequence_parallel/model/layers/preprocess.py b/examples/tutorial/sequence_parallel/model/layers/preprocess.py index dd66bfe13585..55dd20e1e948 100644 --- a/examples/tutorial/sequence_parallel/model/layers/preprocess.py +++ b/examples/tutorial/sequence_parallel/model/layers/preprocess.py @@ -6,7 +6,6 @@ class PreProcessor(nn.Module): - def __init__(self, sub_seq_length): super().__init__() self.sub_seq_length = sub_seq_length @@ -15,10 +14,9 @@ def bert_position_ids(self, token_ids): # Create position ids seq_length = token_ids.size(1) local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) - position_ids = torch.arange(seq_length * local_rank, - seq_length * (local_rank + 1), - dtype=torch.long, - device=token_ids.device) + position_ids = torch.arange( + seq_length * local_rank, seq_length * (local_rank + 1), dtype=torch.long, device=token_ids.device + ) position_ids = position_ids.unsqueeze(0).expand_as(token_ids) return position_ids @@ -42,7 +40,7 @@ def bert_extended_attention_mask(self, attention_mask): extended_attention_mask = attention_mask_bss.unsqueeze(1) # Convert attention mask to binary: - extended_attention_mask = (extended_attention_mask < 0.5) + extended_attention_mask = extended_attention_mask < 0.5 return extended_attention_mask diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py index b8b89cda5525..e9ceb8d70cb8 100644 --- a/examples/tutorial/sequence_parallel/train.py +++ b/examples/tutorial/sequence_parallel/train.py @@ -12,7 +12,6 @@ from colossalai.legacy.amp import AMP_TYPE from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc -from colossalai.legacy.engine.schedule import PipelineSchedule from colossalai.legacy.utils import is_using_pp from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import FusedAdam @@ -31,7 +30,7 @@ def process_batch_data(batch_data): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data") + parser.add_argument("-s", "--synthetic", action="store_true", help="whether use synthetic data") return parser.parse_args() @@ -48,37 +47,39 @@ def pipeline_data_process_func(stage_output, micro_batch_data): def main(): # initialize - args = parse_args() - colossalai.launch_from_torch(config='./config.py', seed=1234, backend='nccl') + parse_args() + colossalai.launch_from_torch(config="./config.py", seed=1234, backend="nccl") logger = get_dist_logger() # build synthetic dataloader BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA) VOCAB_SIZE = 30528 - trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, - vocab_size=VOCAB_SIZE, - seq_length=gpc.config.SEQ_LENGTH) - validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS, - vocab_size=VOCAB_SIZE, - seq_length=gpc.config.SEQ_LENGTH) + trainloader = DummyDataloader( + batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH + ) + validloader = DummyDataloader( + batch_size=BATCH_SIZE_PER_GPUS, vocab_size=VOCAB_SIZE, seq_length=gpc.config.SEQ_LENGTH + ) logger.info("Dataloaders are built", ranks=[0]) # build model - if hasattr(gpc.config, 'fp16') and gpc.config.fp16.get('mode') == AMP_TYPE.NAIVE: + if hasattr(gpc.config, "fp16") and gpc.config.fp16.get("mode") == AMP_TYPE.NAIVE: is_naive_fp16 = True else: is_naive_fp16 = False use_pipeline = is_using_pp() - kwargs = dict(vocab_size=VOCAB_SIZE, - hidden_size=gpc.config.HIDDEN_SIZE, - max_sequence_length=gpc.config.SEQ_LENGTH, - num_attention_heads=gpc.config.NUM_ATTENTION_HEADS, - convert_fp16_to_fp32_in_softmax=True, - is_naive_fp16=is_naive_fp16, - add_binary_head=gpc.config.ADD_BINARY_HEAD) + kwargs = dict( + vocab_size=VOCAB_SIZE, + hidden_size=gpc.config.HIDDEN_SIZE, + max_sequence_length=gpc.config.SEQ_LENGTH, + num_attention_heads=gpc.config.NUM_ATTENTION_HEADS, + convert_fp16_to_fp32_in_softmax=True, + is_naive_fp16=is_naive_fp16, + add_binary_head=gpc.config.ADD_BINARY_HEAD, + ) if use_pipeline: model = build_pipeline_bert(num_layers=gpc.config.DEPTH, num_chunks=1, **kwargs) @@ -99,35 +100,39 @@ def main(): logger.info("Criterion is built", ranks=[0]) # layernorm and bias has no weight decay - weight_decay_params = {'params': []} - no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + weight_decay_params = {"params": []} + no_weight_decay_params = {"params": [], "weight_decay": 0.0} for module_ in model.modules(): if isinstance(module_, LayerNorm): - no_weight_decay_params['params'].extend([p for p in list(module_._parameters.values()) if p is not None]) + no_weight_decay_params["params"].extend([p for p in list(module_._parameters.values()) if p is not None]) else: - weight_decay_params['params'].extend( - [p for n, p in list(module_._parameters.items()) if p is not None and n != 'bias']) - no_weight_decay_params['params'].extend( - [p for n, p in list(module_._parameters.items()) if p is not None and n == 'bias']) + weight_decay_params["params"].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n != "bias"] + ) + no_weight_decay_params["params"].extend( + [p for n, p in list(module_._parameters.items()) if p is not None and n == "bias"] + ) logger.info( f"without weight decay param: {len(no_weight_decay_params['params'])}, with weight decay param: {len(weight_decay_params['params'])}" ) # optimizer - optimizer = FusedAdam((weight_decay_params, no_weight_decay_params), - lr=gpc.config.LR, - weight_decay=gpc.config.WEIGHT_DECAY) + optimizer = FusedAdam( + (weight_decay_params, no_weight_decay_params), lr=gpc.config.LR, weight_decay=gpc.config.WEIGHT_DECAY + ) logger.info("Optimizer is built", ranks=[0]) # lr scheduler # follow Megatron-LM setting warmup_steps = int(gpc.config.DECAY_ITERS * gpc.config.WARMUP_FRACTION) - lr_scheduler = AnnealingLR(optimizer=optimizer, - max_lr=gpc.config.LR, - min_lr=gpc.config.MIN_LR, - warmup_steps=warmup_steps, - decay_steps=gpc.config.DECAY_ITERS, - decay_style='linear') + lr_scheduler = AnnealingLR( + optimizer=optimizer, + max_lr=gpc.config.LR, + min_lr=gpc.config.MIN_LR, + warmup_steps=warmup_steps, + decay_steps=gpc.config.DECAY_ITERS, + decay_style="linear", + ) logger.info(f"LR Scheduler is built with {warmup_steps} warmup steps and {gpc.config.DECAY_ITERS} decay steps") # # init @@ -135,7 +140,6 @@ def main(): # build timer timer = MultiTimer() - skip_iters = 0 # build loss tracker accumulated_train_loss = torch.zeros(1, dtype=torch.float32).cuda() @@ -150,7 +154,7 @@ def main(): logger.info("start training") for step in range(1, gpc.config.TRAIN_ITERS + 1): - timer.start('train-iterations') + timer.start("train-iterations") engine.train() if use_pipeline: engine.zero_grad() @@ -158,13 +162,14 @@ def main(): engine.step() else: tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( - trainloader) + trainloader + ) engine.zero_grad() lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) train_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) engine.backward(train_loss) engine.step() - timer.stop('train-iterations', keep_in_history=True) + timer.stop("train-iterations", keep_in_history=True) if not gpc.is_initialized(ParallelMode.PIPELINE) or gpc.is_last_rank(ParallelMode.PIPELINE): accumulated_train_loss += train_loss @@ -177,12 +182,18 @@ def main(): for j in range(gpc.config.EVAL_ITERS): with torch.no_grad(): if use_pipeline: - _, _, eval_loss = engine.execute_schedule(valid_data_iter, - forward_only=True, - return_output_label=False) + _, _, eval_loss = engine.execute_schedule( + valid_data_iter, forward_only=True, return_output_label=False + ) else: - tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch_for_sequence_parallel( - validloader) + ( + tokens, + types, + sentence_order, + loss_mask, + lm_labels, + padding_mask, + ) = get_batch_for_sequence_parallel(validloader) lm_loss, sop_output = engine(tokens, padding_mask, types, lm_labels) eval_loss = engine.criterion(lm_loss, sop_output, loss_mask, sentence_order) @@ -196,18 +207,22 @@ def main(): timer_string = [] for n, t in timer: timer_string.append(f"{n}: {t.get_history_mean()*1000:.5f}") - timer_string = ' | '.join(timer_string) - lr = list(engine.optimizer.param_groups)[0]['lr'] + timer_string = " | ".join(timer_string) + lr = list(engine.optimizer.param_groups)[0]["lr"] loss_scale = engine.optimizer.optim.loss_scale.item() if gpc.is_initialized(ParallelMode.PIPELINE): ranks = [gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1]] else: ranks = [0] - logger.info(f'Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} ' + - f'| Eval Loss: {accumulated_eval_loss.item():.5g} ' + f'| Loss Scale: {loss_scale}' + - f"| Learning rate: {lr} | " + timer_string, - ranks=ranks) + logger.info( + f"Step {step} / {gpc.config.TRAIN_ITERS} | Train Loss: {accumulated_train_loss.item():.5g} " + + f"| Eval Loss: {accumulated_eval_loss.item():.5g} " + + f"| Loss Scale: {loss_scale}" + + f"| Learning rate: {lr} | " + + timer_string, + ranks=ranks, + ) for n, t in timer: t.reset() @@ -215,5 +230,5 @@ def main(): accumulated_train_loss.zero_() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/op_builder/__init__.py b/op_builder/__init__.py index 5ae7223b8c69..808559ec9c2d 100644 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -7,17 +7,26 @@ from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder ALL_OPS = { - 'cpu_adam': CPUAdamBuilder, - 'fused_optim': FusedOptimBuilder, - 'moe': MOEBuilder, - 'multi_head_attn': MultiHeadAttnBuilder, - 'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder, - 'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder, - 'layernorm': LayerNormBuilder, + "cpu_adam": CPUAdamBuilder, + "fused_optim": FusedOptimBuilder, + "moe": MOEBuilder, + "multi_head_attn": MultiHeadAttnBuilder, + "scaled_masked_softmax": ScaledMaskedSoftmaxBuilder, + "scaled_upper_triangle_masked_softmax": ScaledUpperTrainglemaskedSoftmaxBuilder, + "layernorm": LayerNormBuilder, } __all__ = [ - 'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder', - 'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder', - 'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder' + "ALL_OPS", + "CPUAdamBuilder", + "FusedOptimBuilder", + "MultiHeadAttnBuilder", + "ScaledMaskedSoftmaxBuilder", + "ScaledUpperTrainglemaskedSoftmaxBuilder", + "MOEBuilder", + "MultiTensorSGDBuilder", + "MultiTensorAdamBuilder", + "MultiTensorLambBuilder", + "MultiTensorScaleBuilder", + "MultiTensorL2NormBuilder", ] diff --git a/op_builder/builder.py b/op_builder/builder.py index 8396235e5cfe..75823ef105c7 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -24,13 +24,14 @@ class Builder(ABC): def __init__(self, name: str, prebuilt_import_path: str): self.name = name self.prebuilt_import_path = prebuilt_import_path - self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] # we store the op as an attribute to avoid repeated building and loading self.cached_op_module = None - assert prebuilt_import_path.startswith('colossalai._C'), \ - f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}' + assert prebuilt_import_path.startswith( + "colossalai._C" + ), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}" def relative_to_abs_path(self, code_path: str) -> str: """ @@ -46,10 +47,10 @@ def relative_to_abs_path(self, code_path: str) -> str: # this symlink will be replaced with actual files if we install via pypi # thus we cannot tell the colossalai root directory by checking whether the op_builder # is a symlink, we can only tell whether it is inside or outside colossalai - if str(op_builder_module_path).endswith('colossalai/kernel/op_builder'): + if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): root_path = op_builder_module_path.parent.parent else: - root_path = op_builder_module_path.parent.joinpath('colossalai') + root_path = op_builder_module_path.parent.joinpath("colossalai") code_abs_path = root_path.joinpath(code_path) return str(code_abs_path) @@ -59,13 +60,14 @@ def get_cuda_home_include(self): return include path inside the cuda home. """ from torch.utils.cpp_extension import CUDA_HOME + if CUDA_HOME is None: raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") cuda_include = os.path.join(CUDA_HOME, "include") return cuda_include def csrc_abs_path(self, path): - return os.path.join(self.relative_to_abs_path('kernel/cuda_native/csrc'), path) + return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path) # functions must be overrided begin @abstractmethod @@ -80,27 +82,24 @@ def include_dirs(self) -> List[str]: """ This function should return a list of include files for extensions. """ - pass @abstractmethod def cxx_flags(self) -> List[str]: """ This function should return a list of cxx compilation flags for extensions. """ - pass @abstractmethod def nvcc_flags(self) -> List[str]: """ This function should return a list of nvcc compilation flags for extensions. """ - pass # functions must be overrided over def strip_empty_entries(self, args): - ''' + """ Drop any empty strings from the list of compile and link flags - ''' + """ return [x for x in args if len(x) > 0] def import_op(self): @@ -114,8 +113,8 @@ def check_runtime_build_environment(self): Check whether the system environment is ready for extension compilation. """ try: - import torch from torch.utils.cpp_extension import CUDA_HOME + TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False @@ -123,7 +122,8 @@ def check_runtime_build_environment(self): if not TORCH_AVAILABLE: raise ModuleNotFoundError( - "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions") + "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions" + ) if CUDA_HOME is None: raise RuntimeError( @@ -150,7 +150,7 @@ def load(self, verbose: Optional[bool] = None): verbose (bool, optional): show detailed info. Defaults to True. """ if verbose is None: - verbose = os.environ.get('CAI_KERNEL_VERBOSE', '0') == '1' + verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1" # if the kernel has be compiled and cached, we directly use it if self.cached_op_module is not None: return self.cached_op_module @@ -161,7 +161,8 @@ def load(self, verbose: Optional[bool] = None): op_module = self.import_op() if verbose: print_rank_0( - f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building.") + f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building." + ) except ImportError: # check environment self.check_runtime_build_environment() @@ -172,10 +173,11 @@ def load(self, verbose: Optional[bool] = None): # construct the build directory import torch from torch.utils.cpp_extension import load - torch_version_major = torch.__version__.split('.')[0] - torch_version_minor = torch.__version__.split('.')[1] + + torch_version_major = torch.__version__.split(".")[0] + torch_version_minor = torch.__version__.split(".")[1] torch_cuda_version = torch.version.cuda - home_directory = os.path.expanduser('~') + home_directory = os.path.expanduser("~") extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" build_directory = os.path.join(home_directory, extension_directory) Path(build_directory).mkdir(parents=True, exist_ok=True) @@ -184,14 +186,16 @@ def load(self, verbose: Optional[bool] = None): print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") # load the kernel - op_module = load(name=self.name, - sources=self.strip_empty_entries(self.sources_files()), - extra_include_paths=self.strip_empty_entries(self.include_dirs()), - extra_cflags=self.cxx_flags(), - extra_cuda_cflags=self.nvcc_flags(), - extra_ldflags=[], - build_directory=build_directory, - verbose=verbose) + op_module = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=build_directory, + verbose=verbose, + ) build_duration = time.time() - start_build @@ -204,16 +208,18 @@ def load(self, verbose: Optional[bool] = None): return op_module - def builder(self) -> 'CUDAExtension': + def builder(self) -> "CUDAExtension": """ get a CUDAExtension instance used for setup.py """ from torch.utils.cpp_extension import CUDAExtension - return CUDAExtension(name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args={ - 'cxx': self.strip_empty_entries(self.cxx_flags()), - 'nvcc': self.strip_empty_entries(self.nvcc_flags()) - }) + return CUDAExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + "cxx": self.strip_empty_entries(self.cxx_flags()), + "nvcc": self.strip_empty_entries(self.nvcc_flags()), + }, + ) diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index 500e2cc0eddc..5a2a2e3e6a56 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads @@ -10,29 +8,29 @@ class CPUAdamBuilder(Builder): def __init__(self): super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] # necessary 4 functions def sources_files(self): ret = [ - self.csrc_abs_path('cpu_adam.cpp'), + self.csrc_abs_path("cpu_adam.cpp"), ] return ret def include_dirs(self): - return [ - self.csrc_abs_path("includes"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("includes"), self.get_cuda_home_include()] def cxx_flags(self): - extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] - return ['-O3'] + self.version_dependent_macros + extra_cxx_flags + extra_cxx_flags = ["-std=c++14", "-lcudart", "-lcublas", "-g", "-Wno-reorder", "-fopenmp", "-march=native"] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags def nvcc_flags(self): extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] - ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py index 31ddfced1db2..3baa0880d801 100644 --- a/op_builder/fused_optim.py +++ b/op_builder/fused_optim.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import get_cuda_cc_flag @@ -10,25 +8,30 @@ class FusedOptimBuilder(Builder): def __init__(self): super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) - + def sources_files(self): ret = [ - self.csrc_abs_path(fname) for fname in [ - 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu', - 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu' + self.csrc_abs_path(fname) + for fname in [ + "colossal_C_frontend.cpp", + "multi_tensor_sgd_kernel.cu", + "multi_tensor_scale_kernel.cu", + "multi_tensor_adam.cu", + "multi_tensor_l2norm_kernel.cu", + "multi_tensor_lamb.cu", ] ] return ret def include_dirs(self): - ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] return ret def cxx_flags(self): - version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] - return ['-O3'] + version_dependent_macros + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros def nvcc_flags(self): - extra_cuda_flags = ['-lineinfo'] + extra_cuda_flags = ["-lineinfo"] extra_cuda_flags.extend(get_cuda_cc_flag()) - return ['-O3', '--use_fast_math'] + extra_cuda_flags + return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/op_builder/layernorm.py b/op_builder/layernorm.py index 61d941741929..2684c6ddb7f7 100644 --- a/op_builder/layernorm.py +++ b/op_builder/layernorm.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag @@ -12,18 +10,18 @@ def __init__(self): super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH) def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu']] + ret = [self.csrc_abs_path(fname) for fname in ["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"]] return ret def include_dirs(self): - ret = [self.csrc_abs_path('kernels/include'), self.get_cuda_home_include()] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): - extra_cuda_flags = ['-maxrregcount=50'] + extra_cuda_flags = ["-maxrregcount=50"] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + self.version_dependent_macros + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros return append_nvcc_threads(ret) diff --git a/op_builder/moe.py b/op_builder/moe.py index eeb7d8e3980c..6f8028b1720c 100644 --- a/op_builder/moe.py +++ b/op_builder/moe.py @@ -1,11 +1,8 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag class MOEBuilder(Builder): - NAME = "moe" PREBUILT_IMPORT_PATH = "colossalai._C.moe" @@ -13,24 +10,23 @@ def __init__(self): super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): - ret = [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] return ret def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ['moe_cuda.cpp', 'moe_cuda_kernel.cu']] + ret = [self.csrc_abs_path(fname) for fname in ["moe_cuda.cpp", "moe_cuda_kernel.cu"]] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', - '--expt-extended-lambda' + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py index f9103fe94729..b70f041db7d6 100644 --- a/op_builder/multi_head_attn.py +++ b/op_builder/multi_head_attn.py @@ -1,18 +1,13 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag class MultiHeadAttnBuilder(Builder): - NAME = "multihead_attention" PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" def __init__(self): - super().__init__(name=MultiHeadAttnBuilder.NAME, - prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) - + super().__init__(name=MultiHeadAttnBuilder.NAME, prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] @@ -20,22 +15,31 @@ def include_dirs(self): def sources_files(self): ret = [ - self.csrc_abs_path(fname) for fname in [ - 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu', - 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu', - 'kernels/general_kernels.cu', 'kernels/cuda_util.cu' + self.csrc_abs_path(fname) + for fname in [ + "multihead_attention_1d.cpp", + "kernels/cublas_wrappers.cu", + "kernels/transform_kernels.cu", + "kernels/dropout_kernels.cu", + "kernels/normalize_kernels.cu", + "kernels/softmax_kernels.cu", + "kernels/general_kernels.cu", + "kernels/cuda_util.cu", ] ] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/scaled_masked_softmax.py b/op_builder/scaled_masked_softmax.py index 11cfda39a85c..b2f1de7792c8 100644 --- a/op_builder/scaled_masked_softmax.py +++ b/op_builder/scaled_masked_softmax.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads @@ -9,29 +7,28 @@ class ScaledMaskedSoftmaxBuilder(Builder): PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" def __init__(self): - super().__init__(name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + super().__init__( + name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH + ) # necessary 4 functions def sources_files(self): - ret = [ - self.csrc_abs_path(fname) for fname in - ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'] - ] + ret = [self.csrc_abs_path(fname) for fname in ["scaled_masked_softmax.cpp", "scaled_masked_softmax_cuda.cu"]] return ret def include_dirs(self): - return [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK' + "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", ] - ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py index d0d2433aa645..1445230acbc1 100644 --- a/op_builder/scaled_upper_triangle_masked_softmax.py +++ b/op_builder/scaled_upper_triangle_masked_softmax.py @@ -1,5 +1,3 @@ -import os - from .builder import Builder from .utils import append_nvcc_threads, get_cuda_cc_flag @@ -9,29 +7,31 @@ class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" def __init__(self): - super().__init__(name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + super().__init__( + name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, + prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH, + ) def include_dirs(self): - return [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] def sources_files(self): ret = [ self.csrc_abs_path(fname) - for fname in ['scaled_upper_triang_masked_softmax.cpp', 'scaled_upper_triang_masked_softmax_cuda.cu'] + for fname in ["scaled_upper_triang_masked_softmax.cpp", "scaled_upper_triang_masked_softmax_cuda.cu"] ] return ret def cxx_flags(self): - return ['-O3'] + self.version_dependent_macros + return ["-O3"] + self.version_dependent_macros def nvcc_flags(self): extra_cuda_flags = [ - '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', - '--expt-extended-lambda' + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", ] extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ['-O3', '--use_fast_math'] + extra_cuda_flags + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags return append_nvcc_threads(ret) diff --git a/op_builder/utils.py b/op_builder/utils.py index 9412c725baab..3f75f952d57b 100644 --- a/op_builder/utils.py +++ b/op_builder/utils.py @@ -11,6 +11,7 @@ def print_rank_0(message: str) -> None: """ try: import torch.distributed as dist + if not dist.is_initialized(): is_main_rank = True else: @@ -36,7 +37,8 @@ def get_cuda_version_in_pytorch() -> List[int]: torch_cuda_minor = torch.version.cuda.split(".")[1] except: raise ValueError( - "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda") + "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda" + ) return torch_cuda_major, torch_cuda_minor @@ -50,7 +52,7 @@ def get_cuda_bare_metal_version(cuda_dir) -> List[int]: Returns: The CUDA version required by PyTorch, in the form of tuple (major, minor). """ - nvcc_path = os.path.join(cuda_dir, 'bin/nvcc') + nvcc_path = os.path.join(cuda_dir, "bin/nvcc") if cuda_dir is None: raise ValueError( @@ -85,9 +87,9 @@ def check_system_pytorch_cuda_match(cuda_dir): if bare_metal_major != torch_cuda_major: raise Exception( - f'[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) ' - f'mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor}).' - 'Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ .' + f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) " + f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})." + "Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ." ) if bare_metal_minor != torch_cuda_minor: @@ -107,10 +109,11 @@ def get_pytorch_version() -> List[int]: A tuple of integers in the form of (major, minor, patch). """ import torch - torch_version = torch.__version__.split('+')[0] - TORCH_MAJOR = int(torch_version.split('.')[0]) - TORCH_MINOR = int(torch_version.split('.')[1]) - TORCH_PATCH = int(torch_version.split('.')[2], 16) + + torch_version = torch.__version__.split("+")[0] + TORCH_MAJOR = int(torch_version.split(".")[0]) + TORCH_MINOR = int(torch_version.split(".")[1]) + TORCH_PATCH = int(torch_version.split(".")[2], 16) return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH @@ -132,7 +135,8 @@ def check_pytorch_version(min_major_version, min_minor_version) -> bool: if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version): raise RuntimeError( f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n" - "The latest stable release can be obtained from https://pytorch.org/get-started/locally/") + "The latest stable release can be obtained from https://pytorch.org/get-started/locally/" + ) def check_cuda_availability(): @@ -143,6 +147,7 @@ def check_cuda_availability(): A boolean value. True if CUDA is available and False otherwise. """ import torch + return torch.cuda.is_available() @@ -155,29 +160,31 @@ def set_cuda_arch_list(cuda_dir): # we only need to set this when CUDA is not available for cross-compilation if not cuda_available: - warnings.warn('\n[extension] PyTorch did not find available GPUs on this system.\n' - 'If your intention is to cross-compile, this is not an error.\n' - 'By default, Colossal-AI will cross-compile for \n' - '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n' - '2. Volta (compute capability 7.0)\n' - '3. Turing (compute capability 7.5),\n' - '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n' - '\nIf you wish to cross-compile for a single specific architecture,\n' - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') + warnings.warn( + "\n[extension] PyTorch did not find available GPUs on this system.\n" + "If your intention is to cross-compile, this is not an error.\n" + "By default, Colossal-AI will cross-compile for \n" + "1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "2. Volta (compute capability 7.0)\n" + "3. Turing (compute capability 7.5),\n" + "4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n" + "\nIf you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n' + ) if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) - arch_list = ['6.0', '6.1', '6.2', '7.0', '7.5'] + arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"] if int(bare_metal_major) == 11: if int(bare_metal_minor) == 0: - arch_list.append('8.0') + arch_list.append("8.0") else: - arch_list.append('8.0') - arch_list.append('8.6') + arch_list.append("8.0") + arch_list.append("8.6") - arch_list_str = ';'.join(arch_list) + arch_list_str = ";".join(arch_list) os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str return False return True @@ -197,13 +204,13 @@ def get_cuda_cc_flag() -> List[str]: import torch cc_flag = [] - max_arch = ''.join(str(i) for i in torch.cuda.get_device_capability()) + max_arch = "".join(str(i) for i in torch.cuda.get_device_capability()) for arch in torch.cuda.get_arch_list(): - res = re.search(r'sm_(\d+)', arch) + res = re.search(r"sm_(\d+)", arch) if res: arch_cap = res[1] if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch): - cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}']) + cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) return cc_flag diff --git a/setup.py b/setup.py index 5d8f831218d9..cda1ba7ee7a6 100644 --- a/setup.py +++ b/setup.py @@ -15,8 +15,8 @@ ) try: - import torch from torch.utils.cpp_extension import CUDA_HOME, BuildExtension + TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False @@ -26,14 +26,14 @@ MIN_PYTORCH_VERSION_MAJOR = 1 MIN_PYTORCH_VERSION_MINOR = 10 THIS_DIR = os.path.dirname(os.path.abspath(__file__)) -BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1 -IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1 +BUILD_CUDA_EXT = int(os.environ.get("CUDA_EXT", "0")) == 1 +IS_NIGHTLY = int(os.environ.get("NIGHTLY", "0")) == 1 # a variable to store the op builder ext_modules = [] # we do not support windows currently -if sys.platform == 'win32': +if sys.platform == "win32": raise RuntimeError("Windows is not supported yet. Please try again within the Windows Subsystem for Linux (WSL).") @@ -64,7 +64,7 @@ def fetch_requirements(path) -> List[str]: Returns: The lines in the requirements file. """ - with open(path, 'r') as fd: + with open(path, "r") as fd: return [r.strip() for r in fd.readlines()] @@ -75,7 +75,7 @@ def fetch_readme() -> str: Returns: The lines in the README file. """ - with open('README.md', encoding='utf-8') as f: + with open("README.md", encoding="utf-8") as f: return f.read() @@ -89,21 +89,21 @@ def get_version() -> str: setup_file_path = os.path.abspath(__file__) project_path = os.path.dirname(setup_file_path) - version_txt_path = os.path.join(project_path, 'version.txt') - version_py_path = os.path.join(project_path, 'colossalai/version.py') + version_txt_path = os.path.join(project_path, "version.txt") + version_py_path = os.path.join(project_path, "colossalai/version.py") with open(version_txt_path) as f: version = f.read().strip() # write version into version.py - with open(version_py_path, 'w') as f: + with open(version_py_path, "w") as f: f.write(f"__version__ = '{version}'\n") # look for pytorch and cuda version if BUILD_CUDA_EXT: torch_major, torch_minor, _ = get_pytorch_version() - torch_version = f'{torch_major}.{torch_minor}' - cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME)) + torch_version = f"{torch_major}.{torch_minor}" + cuda_version = ".".join(get_cuda_bare_metal_version(CUDA_HOME)) else: torch_version = None cuda_version = None @@ -112,12 +112,12 @@ def get_version() -> str: if torch_version: f.write(f'torch = "{torch_version}"\n') else: - f.write('torch = None\n') + f.write("torch = None\n") if cuda_version: f.write(f'cuda = "{cuda_version}"\n') else: - f.write('cuda = None\n') + f.write("cuda = None\n") return version @@ -127,6 +127,7 @@ def get_version() -> str: set_cuda_arch_list(CUDA_HOME) from op_builder import ALL_OPS + op_names = [] # load all builders @@ -135,7 +136,7 @@ def get_version() -> str: ext_modules.append(builder_cls().builder()) # show log - op_name_list = ', '.join(op_names) + op_name_list = ", ".join(op_names) print(f"[extension] loaded builders for {op_name_list}") # always put not nightly branch as the if branch @@ -143,56 +144,62 @@ def get_version() -> str: # and it will mess up with the dependency graph insights if not IS_NIGHTLY: version = get_version() - package_name = 'colossalai' + package_name = "colossalai" else: # use date as the nightly version - version = datetime.today().strftime('%Y.%m.%d') - package_name = 'colossalai-nightly' - -setup(name=package_name, - version=version, - packages=find_packages(exclude=( - 'op_builder', - 'benchmark', - 'docker', - 'tests', - 'docs', - 'examples', - 'tests', - 'scripts', - 'requirements', - '*.egg-info', - )), - description='An integrated large-scale model training system with efficient parallelization techniques', - long_description=fetch_readme(), - long_description_content_type='text/markdown', - license='Apache Software License 2.0', - url='https://www.colossalai.org', - project_urls={ - 'Forum': 'https://github.com/hpcaitech/ColossalAI/discussions', - 'Bug Tracker': 'https://github.com/hpcaitech/ColossalAI/issues', - 'Examples': 'https://github.com/hpcaitech/ColossalAI-Examples', - 'Documentation': 'http://colossalai.readthedocs.io', - 'Github': 'https://github.com/hpcaitech/ColossalAI', - }, - ext_modules=ext_modules, - cmdclass={'build_ext': BuildExtension} if ext_modules else {}, - install_requires=fetch_requirements('requirements/requirements.txt'), - entry_points=''' + version = datetime.today().strftime("%Y.%m.%d") + package_name = "colossalai-nightly" + +setup( + name=package_name, + version=version, + packages=find_packages( + exclude=( + "op_builder", + "benchmark", + "docker", + "tests", + "docs", + "examples", + "tests", + "scripts", + "requirements", + "*.egg-info", + ) + ), + description="An integrated large-scale model training system with efficient parallelization techniques", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://www.colossalai.org", + project_urls={ + "Forum": "https://github.com/hpcaitech/ColossalAI/discussions", + "Bug Tracker": "https://github.com/hpcaitech/ColossalAI/issues", + "Examples": "https://github.com/hpcaitech/ColossalAI-Examples", + "Documentation": "http://colossalai.readthedocs.io", + "Github": "https://github.com/hpcaitech/ColossalAI", + }, + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension} if ext_modules else {}, + install_requires=fetch_requirements("requirements/requirements.txt"), + entry_points=""" [console_scripts] colossalai=colossalai.cli:cli - ''', - python_requires='>=3.6', - classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Environment :: GPU :: NVIDIA CUDA', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: System :: Distributed Computing', - ], - package_data={ - 'colossalai': [ - '_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*', - 'kernel/cuda_native/csrc/kernels/include/*' - ] - }) + """, + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], + package_data={ + "colossalai": [ + "_C/*.pyi", + "kernel/cuda_native/csrc/*", + "kernel/cuda_native/csrc/kernel/*", + "kernel/cuda_native/csrc/kernels/include/*", + ] + }, +) diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py index f29efefce4a4..65eaa72d6e84 100644 --- a/tests/components_to_test/__init__.py +++ b/tests/components_to_test/__init__.py @@ -11,9 +11,19 @@ ) from .utils import run_fwd, run_fwd_bwd -from . import albert # isort:skip +from . import albert # isort:skip __all__ = [ - 'bert', 'gpt2', 'hanging_param_model', 'inline_op_model', 'nested_model', 'repeated_computed_layers', 'resnet', - 'simple_net', 'run_fwd_bwd', 'albert', 'beit', 'run_fwd' + "bert", + "gpt2", + "hanging_param_model", + "inline_op_model", + "nested_model", + "repeated_computed_layers", + "resnet", + "simple_net", + "run_fwd_bwd", + "albert", + "beit", + "run_fwd", ] diff --git a/tests/components_to_test/albert.py b/tests/components_to_test/albert.py index 8924eb2fbc92..0ba4d19655cd 100644 --- a/tests/components_to_test/albert.py +++ b/tests/components_to_test/albert.py @@ -1,13 +1,11 @@ import torch -import transformers -from packaging import version from transformers import AlbertConfig, AlbertForSequenceClassification from .bert import get_bert_data_loader from .registry import non_distributed_component_funcs -@non_distributed_component_funcs.register(name='albert') +@non_distributed_component_funcs.register(name="albert") def get_training_components(): hidden_dim = 8 num_head = 4 @@ -16,20 +14,21 @@ def get_training_components(): vocab_size = 32 def bert_model_builder(checkpoint: bool = False): - config = AlbertConfig(vocab_size=vocab_size, - gradient_checkpointing=checkpoint, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0., - attention_probs_dropout_prob=0.) - print('building AlbertForSequenceClassification model') + config = AlbertConfig( + vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + ) + print("building AlbertForSequenceClassification model") # adapting huggingface BertForSequenceClassification for single unittest calling interface class ModelAdaptor(AlbertForSequenceClassification): - def forward(self, input_ids, labels): """ inputs: data, label @@ -44,16 +43,20 @@ def forward(self, input_ids, labels): return model is_distributed = torch.distributed.is_initialized() - trainloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed) - testloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed) + trainloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) + testloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) criterion = None return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/beit.py b/tests/components_to_test/beit.py index 2021ae6f6e35..d33474ea9a6b 100644 --- a/tests/components_to_test/beit.py +++ b/tests/components_to_test/beit.py @@ -14,25 +14,27 @@ class DummyDataLoader(DummyDataGenerator): batch_size = 4 def generate(self): - data = torch.randn((DummyDataLoader.batch_size, DummyDataLoader.num_channel, DummyDataLoader.img_size, - DummyDataLoader.img_size), - device=get_current_device()) - label = torch.randint(low=0, - high=DummyDataLoader.num_class, - size=(DummyDataLoader.batch_size,), - device=get_current_device()) + data = torch.randn( + ( + DummyDataLoader.batch_size, + DummyDataLoader.num_channel, + DummyDataLoader.img_size, + DummyDataLoader.img_size, + ), + device=get_current_device(), + ) + label = torch.randint( + low=0, high=DummyDataLoader.num_class, size=(DummyDataLoader.batch_size,), device=get_current_device() + ) return data, label -@non_distributed_component_funcs.register(name='beit') +@non_distributed_component_funcs.register(name="beit") def get_training_components(): - def model_builder(checkpoint=False): - model = Beit(img_size=DummyDataLoader.img_size, - num_classes=DummyDataLoader.num_class, - embed_dim=32, - depth=2, - num_heads=4) + model = Beit( + img_size=DummyDataLoader.img_size, num_classes=DummyDataLoader.num_class, embed_dim=32, depth=2, num_heads=4 + ) return model trainloader = DummyDataLoader() diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py index e7d1d50806b8..f0061ad18c84 100644 --- a/tests/components_to_test/bert.py +++ b/tests/components_to_test/bert.py @@ -8,12 +8,12 @@ def get_bert_data_loader( - n_class, - batch_size, - total_samples, - sequence_length, - device=torch.device('cpu:0'), - is_distributed=False, + n_class, + batch_size, + total_samples, + sequence_length, + device=torch.device("cpu:0"), + is_distributed=False, ): train_data = torch.randint( low=0, @@ -32,7 +32,7 @@ def get_bert_data_loader( return train_loader -@non_distributed_component_funcs.register(name='bert') +@non_distributed_component_funcs.register(name="bert") def get_training_components(): hidden_dim = 8 num_head = 4 @@ -41,20 +41,21 @@ def get_training_components(): vocab_size = 32 def bert_model_builder(checkpoint: bool = False): - config = BertConfig(vocab_size=vocab_size, - gradient_checkpointing=checkpoint, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0., - attention_probs_dropout_prob=0.) - print('building BertForSequenceClassification model') + config = BertConfig( + vocab_size=vocab_size, + gradient_checkpointing=checkpoint, + hidden_size=hidden_dim, + intermediate_size=hidden_dim * 4, + num_attention_heads=num_head, + max_position_embeddings=sequence_length, + num_hidden_layers=num_layer, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + ) + print("building BertForSequenceClassification model") # adapting huggingface BertForSequenceClassification for single unittest calling interface class ModelAdaptor(BertForSequenceClassification): - def forward(self, input_ids, labels): """ inputs: data, label @@ -69,16 +70,20 @@ def forward(self, input_ids, labels): return model is_distributed = torch.distributed.is_initialized() - trainloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed) - testloader = get_bert_data_loader(n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed) + trainloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) + testloader = get_bert_data_loader( + n_class=vocab_size, + batch_size=2, + total_samples=10000, + sequence_length=sequence_length, + is_distributed=is_distributed, + ) criterion = None return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/gpt2.py b/tests/components_to_test/gpt2.py index fe25b4923fa2..7f826497d2ab 100644 --- a/tests/components_to_test/gpt2.py +++ b/tests/components_to_test/gpt2.py @@ -14,33 +14,40 @@ class DummyDataLoader(DummyDataGenerator): seq_len = 64 def generate(self): - input_ids = torch.randint(0, - DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len), - device=get_current_device()) + input_ids = torch.randint( + 0, + DummyDataLoader.vocab_size, + (DummyDataLoader.batch_size, DummyDataLoader.seq_len), + device=get_current_device(), + ) return input_ids, input_ids class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50304, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50304, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size, - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + ) + ) if checkpoint: self.model.gradient_checkpointing_enable() @@ -51,12 +58,9 @@ def forward(self, input_ids): def gpt2_micro(checkpoint=True): - return GPTLMModel(checkpoint=checkpoint, - hidden_size=32, - num_layers=2, - num_attention_heads=4, - max_seq_len=64, - vocab_size=128) + return GPTLMModel( + checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128 + ) def gpt2_s(checkpoint=True): @@ -68,7 +72,6 @@ def gpt2_m(checkpoint=True): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -80,9 +83,8 @@ def forward(self, logits, labels): return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) -@non_distributed_component_funcs.register(name='gpt2') +@non_distributed_component_funcs.register(name="gpt2") def get_training_components(): - trainloader = DummyDataLoader() testloader = DummyDataLoader() diff --git a/tests/components_to_test/hanging_param_model.py b/tests/components_to_test/hanging_param_model.py index 0e65431217c7..5531c8d081a0 100644 --- a/tests/components_to_test/hanging_param_model.py +++ b/tests/components_to_test/hanging_param_model.py @@ -28,16 +28,14 @@ def forward(self, x): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 4) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='hanging_param_model') +@non_distributed_component_funcs.register(name="hanging_param_model") def get_training_components(): - def model_builder(checkpoint=False): return HangingParamModule(checkpoint) @@ -46,4 +44,5 @@ def model_builder(checkpoint=False): criterion = torch.nn.CrossEntropyLoss() from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/inline_op_model.py b/tests/components_to_test/inline_op_model.py index 80757f361d9e..8bfa9cf34353 100644 --- a/tests/components_to_test/inline_op_model.py +++ b/tests/components_to_test/inline_op_model.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import torch.nn.functional as F from colossalai.legacy.nn import CheckpointModule @@ -19,7 +18,6 @@ def __init__(self, checkpoint=False) -> None: self.proj2 = nn.Linear(8, 8) def forward(self, x): - x = self.proj1(x) # inline add_ x.add_(10) @@ -31,16 +29,14 @@ def forward(self, x): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 4) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='inline_op_model') +@non_distributed_component_funcs.register(name="inline_op_model") def get_training_components(): - def model_builder(checkpoint=False): return InlineOpModule(checkpoint) @@ -49,4 +45,5 @@ def model_builder(checkpoint=False): criterion = torch.nn.CrossEntropyLoss() from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/nested_model.py b/tests/components_to_test/nested_model.py index 3e779b0a6428..44577456dec5 100644 --- a/tests/components_to_test/nested_model.py +++ b/tests/components_to_test/nested_model.py @@ -9,7 +9,6 @@ class SubNet(nn.Module): - def __init__(self, out_features) -> None: super().__init__() self.bias = nn.Parameter(torch.zeros(out_features)) @@ -19,7 +18,6 @@ def forward(self, x, weight): class NestedNet(CheckpointModule): - def __init__(self, checkpoint=False) -> None: super().__init__(checkpoint) self.fc1 = nn.Linear(5, 5) @@ -35,16 +33,14 @@ def forward(self, x): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 5) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='nested_model') +@non_distributed_component_funcs.register(name="nested_model") def get_training_components(): - def model_builder(checkpoint=False): return NestedNet(checkpoint) diff --git a/tests/components_to_test/registry.py b/tests/components_to_test/registry.py index edfcaaa7275b..ec561b7831ad 100644 --- a/tests/components_to_test/registry.py +++ b/tests/components_to_test/registry.py @@ -2,7 +2,6 @@ class Registry: - def __init__(self): self._registry = dict() @@ -36,4 +35,4 @@ def __next__(self): non_distributed_component_funcs = Registry() model_parallel_component_funcs = Registry() -__all__ = ['non_distributed_component_funcs', 'model_parallel_component_funcs'] +__all__ = ["non_distributed_component_funcs", "model_parallel_component_funcs"] diff --git a/tests/components_to_test/repeated_computed_layers.py b/tests/components_to_test/repeated_computed_layers.py index c1ef99aa07b4..3da64de3fb64 100644 --- a/tests/components_to_test/repeated_computed_layers.py +++ b/tests/components_to_test/repeated_computed_layers.py @@ -29,16 +29,14 @@ def forward(self, x): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.rand(16, 5) label = torch.randint(low=0, high=2, size=(16,)) return data, label -@non_distributed_component_funcs.register(name='repeated_computed_layers') +@non_distributed_component_funcs.register(name="repeated_computed_layers") def get_training_components(): - def model_builder(checkpoint=False): return NetWithRepeatedlyComputedLayers(checkpoint) diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py index df01e4c4847e..a43becc16233 100644 --- a/tests/components_to_test/resnet.py +++ b/tests/components_to_test/resnet.py @@ -13,19 +13,20 @@ def get_cifar10_dataloader(train): # build dataloaders - dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - train=train, - transform=transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])) + dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + train=train, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))] + ), + ) dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True) return dataloader -@non_distributed_component_funcs.register(name='resnet18') +@non_distributed_component_funcs.register(name="resnet18") def get_resnet_training_components(): - def model_builder(checkpoint=False): return resnet18(num_classes=10) diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py index 064974a15a97..0f0ac5cff49a 100644 --- a/tests/components_to_test/simple_net.py +++ b/tests/components_to_test/simple_net.py @@ -33,16 +33,14 @@ def forward(self, x): class DummyDataLoader(DummyDataGenerator): - def generate(self): data = torch.randint(low=0, high=20, size=(16,), device=get_current_device()) label = torch.randint(low=0, high=2, size=(16,), device=get_current_device()) return data, label -@non_distributed_component_funcs.register(name='simple_net') +@non_distributed_component_funcs.register(name="simple_net") def get_training_components(): - def model_builder(checkpoint=False): return SimpleNet(checkpoint) @@ -51,4 +49,5 @@ def model_builder(checkpoint=False): criterion = torch.nn.CrossEntropyLoss() from colossalai.nn.optimizer import HybridAdam + return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/utils/dummy_data_generator.py b/tests/components_to_test/utils/dummy_data_generator.py index 5ab33e86de23..7b3af46c8f35 100644 --- a/tests/components_to_test/utils/dummy_data_generator.py +++ b/tests/components_to_test/utils/dummy_data_generator.py @@ -2,7 +2,6 @@ class DummyDataGenerator(ABC): - def __init__(self, length=10): self.length = length diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index 466a2a558829..c08fd365d871 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -1,4 +1,4 @@ from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers from .registry import model_zoo -__all__ = ['model_zoo'] +__all__ = ["model_zoo"] diff --git a/tests/kit/model_zoo/diffusers/diffusers.py b/tests/kit/model_zoo/diffusers/diffusers.py index 204c1d7773ca..895ee7967f6b 100644 --- a/tests/kit/model_zoo/diffusers/diffusers.py +++ b/tests/kit/model_zoo/diffusers/diffusers.py @@ -4,7 +4,7 @@ import torch import transformers -from ..registry import ModelAttribute, model_zoo +from ..registry import model_zoo BATCH_SIZE = 2 SEQ_LENGTH = 5 @@ -26,10 +26,9 @@ def data_clip_model(): attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) - return dict(input_ids=input_ids, - pixel_values=pixel_values, - attention_mask=attention_mask, - position_ids=position_ids) + return dict( + input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids + ) def data_clip_text(): @@ -43,32 +42,41 @@ def data_clip_vision(): return dict(pixel_values=pixel_values) -model_zoo.register(name='diffusers_auto_encoder_kl', - model_fn=diffusers.AutoencoderKL, - data_gen_fn=data_vae_fn, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_vq_model', - model_fn=diffusers.VQModel, - data_gen_fn=data_vae_fn, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_clip_model', - model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()), - data_gen_fn=data_clip_model, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_clip_text_model', - model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()), - data_gen_fn=data_clip_text, - output_transform_fn=identity_output) - -model_zoo.register(name='diffusers_clip_vision_model', - model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()), - data_gen_fn=data_clip_vision, - output_transform_fn=clip_vision_model_output) - -model_zoo.register(name='diffusers_unet2d_model', - model_fn=diffusers.UNet2DModel, - data_gen_fn=data_unet_fn, - output_transform_fn=identity_output) +model_zoo.register( + name="diffusers_auto_encoder_kl", + model_fn=diffusers.AutoencoderKL, + data_gen_fn=data_vae_fn, + output_transform_fn=identity_output, +) + +model_zoo.register( + name="diffusers_vq_model", model_fn=diffusers.VQModel, data_gen_fn=data_vae_fn, output_transform_fn=identity_output +) + +model_zoo.register( + name="diffusers_clip_model", + model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()), + data_gen_fn=data_clip_model, + output_transform_fn=identity_output, +) + +model_zoo.register( + name="diffusers_clip_text_model", + model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()), + data_gen_fn=data_clip_text, + output_transform_fn=identity_output, +) + +model_zoo.register( + name="diffusers_clip_vision_model", + model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()), + data_gen_fn=data_clip_vision, + output_transform_fn=clip_vision_model_output, +) + +model_zoo.register( + name="diffusers_unet2d_model", + model_fn=diffusers.UNet2DModel, + data_gen_fn=data_unet_fn, + output_transform_fn=identity_output, +) diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 1e7ef3b62736..b90972291870 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Callable -__all__ = ['ModelZooRegistry', 'ModelAttribute', 'model_zoo'] +__all__ = ["ModelZooRegistry", "ModelAttribute", "model_zoo"] @dataclass @@ -14,6 +14,7 @@ class ModelAttribute: has_control_flow (bool): Whether the model contains branching in its forward method. has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability. Often seen in the torchvision models. """ + has_control_flow: bool = False has_stochastic_depth_prob: bool = False @@ -23,13 +24,15 @@ class ModelZooRegistry(dict): A registry to map model names to model and data generation functions. """ - def register(self, - name: str, - model_fn: Callable, - data_gen_fn: Callable, - output_transform_fn: Callable, - loss_fn: Callable = None, - model_attribute: ModelAttribute = None): + def register( + self, + name: str, + model_fn: Callable, + data_gen_fn: Callable, + output_transform_fn: Callable, + loss_fn: Callable = None, + model_attribute: ModelAttribute = None, + ): """ Register a model and data generation function. @@ -71,7 +74,7 @@ def get_sub_registry(self, keyword: str): if keyword in k: new_dict[k] = v - assert len(new_dict) > 0, f'No model found with keyword {keyword}' + assert len(new_dict) > 0, f"No model found with keyword {keyword}" return new_dict diff --git a/tests/kit/model_zoo/timm/timm.py b/tests/kit/model_zoo/timm/timm.py index b29ac12a6b53..eb6d2f6bc757 100644 --- a/tests/kit/model_zoo/timm/timm.py +++ b/tests/kit/model_zoo/timm/timm.py @@ -9,151 +9,183 @@ data_gen_fn = lambda: dict(x=torch.rand(2, 3, 224, 224)) output_transform_fn = lambda x: dict(output=x) -model_zoo.register(name='timm_resnet', - model_fn=tm.resnest.resnest50d, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_beit', - model_fn=tm.beit.beit_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_cait', - model_fn=tm.cait.cait_s24_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_convmixer', - model_fn=tm.convmixer.convmixer_768_32, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_efficientnetv2', - model_fn=tm.efficientnet.efficientnetv2_m, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_resmlp', - model_fn=tm.resmlp_12_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_vision_transformer', - model_fn=tm.vision_transformer.vit_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_deit', - model_fn=tm.deit_base_distilled_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_beitv2', - model_fn=tm.beitv2_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_coat', - model_fn=tm.coat.coat_lite_mini, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="timm_resnet", model_fn=tm.resnest.resnest50d, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_beit", + model_fn=tm.beit.beit_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_cait", model_fn=tm.cait.cait_s24_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_convmixer", + model_fn=tm.convmixer.convmixer_768_32, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_efficientnetv2", + model_fn=tm.efficientnet.efficientnetv2_m, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_resmlp", model_fn=tm.resmlp_12_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_vision_transformer", + model_fn=tm.vision_transformer.vit_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_deit", + model_fn=tm.deit_base_distilled_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_beitv2", + model_fn=tm.beitv2_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_coat", model_fn=tm.coat.coat_lite_mini, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) -model_zoo.register(name='timm_deit3', - model_fn=tm.deit3_base_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="timm_deit3", + model_fn=tm.deit3_base_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) -model_zoo.register(name='timm_eca_nfnet', - model_fn=tm.eca_nfnet_l0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_efficientformer', - model_fn=tm.efficientformer_l1, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_ese_vovnet19b_dw', - model_fn=tm.ese_vovnet19b_dw, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_gmixer_12_224', - model_fn=tm.gmixer_12_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_gmlp_b16_224', - model_fn=tm.gmlp_b16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_hardcorenas_a', - model_fn=tm.hardcorenas_a, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_hrnet_w18_small', - model_fn=tm.hrnet_w18_small, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_inception_v3', - model_fn=tm.inception_v3, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_mixer_b16_224', - model_fn=tm.mixer_b16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_nf_ecaresnet101', - model_fn=tm.nf_ecaresnet101, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_nf_regnet_b0', - model_fn=tm.nf_regnet_b0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_regnetv_040', - model_fn=tm.regnetv_040, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_skresnet18', - model_fn=tm.skresnet18, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_tnt_b_patch16_224', - model_fn=tm.tnt_b_patch16_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_wide_resnet50_2', - model_fn=tm.wide_resnet50_2, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_convit', - model_fn=tm.convit_base, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='timm_dm_nfnet', - model_fn=tm.dm_nfnet_f0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="timm_eca_nfnet", model_fn=tm.eca_nfnet_l0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_efficientformer", + model_fn=tm.efficientformer_l1, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_ese_vovnet19b_dw", + model_fn=tm.ese_vovnet19b_dw, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_gmixer_12_224", + model_fn=tm.gmixer_12_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_gmlp_b16_224", model_fn=tm.gmlp_b16_224, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_hardcorenas_a", + model_fn=tm.hardcorenas_a, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_hrnet_w18_small", + model_fn=tm.hrnet_w18_small, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_inception_v3", model_fn=tm.inception_v3, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_mixer_b16_224", + model_fn=tm.mixer_b16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_nf_ecaresnet101", + model_fn=tm.nf_ecaresnet101, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_nf_regnet_b0", model_fn=tm.nf_regnet_b0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_regnetv_040", model_fn=tm.regnetv_040, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_skresnet18", model_fn=tm.skresnet18, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_tnt_b_patch16_224", + model_fn=tm.tnt_b_patch16_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_wide_resnet50_2", + model_fn=tm.wide_resnet50_2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="timm_convit", model_fn=tm.convit_base, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="timm_dm_nfnet", model_fn=tm.dm_nfnet_f0, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) # ============== # Register models with control flow # ============== -model_zoo.register(name='timm_convnext', - model_fn=tm.convnext.convnext_base, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_vgg', - model_fn=tm.vgg.vgg11, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_dpn', - model_fn=tm.dpn.dpn68, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_densenet', - model_fn=tm.densenet.densenet121, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_rexnet', - model_fn=tm.rexnet.rexnet_100, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='timm_swin_transformer', - model_fn=tm.swin_transformer.swin_base_patch4_window7_224, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="timm_convnext", + model_fn=tm.convnext.convnext_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_vgg", + model_fn=tm.vgg.vgg11, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_dpn", + model_fn=tm.dpn.dpn68, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_densenet", + model_fn=tm.densenet.densenet121, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_rexnet", + model_fn=tm.rexnet.rexnet_100, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="timm_swin_transformer", + model_fn=tm.swin_transformer.swin_base_patch4_window7_224, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/torchaudio/torchaudio.py b/tests/kit/model_zoo/torchaudio/torchaudio.py index 9a244ac312c0..03f565c04553 100644 --- a/tests/kit/model_zoo/torchaudio/torchaudio.py +++ b/tests/kit/model_zoo/torchaudio/torchaudio.py @@ -23,24 +23,31 @@ def conformer_data_gen_fn(): transformer_output_transform_fn = lambda outputs: dict(frames=outputs[0], lengths=outputs[1]) -model_zoo.register(name='torchaudio_conformer', - model_fn=lambda: tm.Conformer( - input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31), - data_gen_fn=conformer_data_gen_fn, - output_transform_fn=transformer_output_transform_fn) +model_zoo.register( + name="torchaudio_conformer", + model_fn=lambda: tm.Conformer( + input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31 + ), + data_gen_fn=conformer_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, +) single_output_transform_fn = lambda output: dict(output=output) -model_zoo.register(name='torchaudio_convtasnet', - model_fn=tm.ConvTasNet, - data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)), - output_transform_fn=single_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_convtasnet", + model_fn=tm.ConvTasNet, + data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)), + output_transform_fn=single_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='torchaudio_deepspeech', - model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4), - data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)), - output_transform_fn=single_output_transform_fn) +model_zoo.register( + name="torchaudio_deepspeech", + model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4), + data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)), + output_transform_fn=single_output_transform_fn, +) def emformer_data_gen_fn(): @@ -50,21 +57,26 @@ def emformer_data_gen_fn(): model_zoo.register( - name='torchaudio_emformer', + name="torchaudio_emformer", model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES, num_heads=4, ffn_dim=128, num_layers=4, segment_length=4), data_gen_fn=emformer_data_gen_fn, output_transform_fn=transformer_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='torchaudio_wav2letter_waveform', - model_fn=lambda: tm.Wav2Letter(input_type='waveform', num_features=40), - data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), - output_transform_fn=single_output_transform_fn) +model_zoo.register( + name="torchaudio_wav2letter_waveform", + model_fn=lambda: tm.Wav2Letter(input_type="waveform", num_features=40), + data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), + output_transform_fn=single_output_transform_fn, +) -model_zoo.register(name='torchaudio_wav2letter_mfcc', - model_fn=lambda: tm.Wav2Letter(input_type='mfcc', num_features=40), - data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), - output_transform_fn=single_output_transform_fn) +model_zoo.register( + name="torchaudio_wav2letter_mfcc", + model_fn=lambda: tm.Wav2Letter(input_type="mfcc", num_features=40), + data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)), + output_transform_fn=single_output_transform_fn, +) def wavernn_data_gen_fn(): @@ -73,20 +85,24 @@ def wavernn_data_gen_fn(): return dict(waveform=waveform, specgram=specgram) -model_zoo.register(name='torchaudio_wavernn', - model_fn=lambda: tm.WaveRNN(upsample_scales=[2, 2, 5], - n_classes=N_CLASSES, - hop_length=HOP_LENGTH, - kernel_size=KERNEL_SIZE, - n_freq=N_FREQ, - n_res_block=2, - n_rnn=64, - n_fc=64, - n_hidden=16, - n_output=16), - data_gen_fn=wavernn_data_gen_fn, - output_transform_fn=single_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_wavernn", + model_fn=lambda: tm.WaveRNN( + upsample_scales=[2, 2, 5], + n_classes=N_CLASSES, + hop_length=HOP_LENGTH, + kernel_size=KERNEL_SIZE, + n_freq=N_FREQ, + n_res_block=2, + n_rnn=64, + n_fc=64, + n_hidden=16, + n_output=16, + ), + data_gen_fn=wavernn_data_gen_fn, + output_transform_fn=single_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) def tacotron_data_gen_fn(): @@ -97,17 +113,18 @@ def tacotron_data_gen_fn(): token_lengths = max_text_length * torch.ones((n_batch,)) mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length) mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,)) - return dict(tokens=tokens, - token_lengths=token_lengths, - mel_specgram=mel_specgram, - mel_specgram_lengths=mel_specgram_lengths) + return dict( + tokens=tokens, token_lengths=token_lengths, mel_specgram=mel_specgram, mel_specgram_lengths=mel_specgram_lengths + ) -model_zoo.register(name='torchaudio_tacotron', - model_fn=lambda: tm.Tacotron2(n_mels=N_MELS), - data_gen_fn=tacotron_data_gen_fn, - output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)), - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_tacotron", + model_fn=lambda: tm.Tacotron2(n_mels=N_MELS), + data_gen_fn=tacotron_data_gen_fn, + output_transform_fn=lambda outputs: dict(summed_output=sum(x.sum() for x in outputs)), + model_attribute=ModelAttribute(has_control_flow=True), +) def wav2vec_data_gen_fn(): @@ -117,14 +134,18 @@ def wav2vec_data_gen_fn(): return dict(waveforms=waveforms, lengths=lengths) -model_zoo.register(name='torchaudio_wav2vec2_base', - model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0), - data_gen_fn=wav2vec_data_gen_fn, - output_transform_fn=transformer_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_wav2vec2_base", + model_fn=partial(tm.wav2vec2_base, encoder_layer_drop=0.0), + data_gen_fn=wav2vec_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='torchaudio_hubert_base', - model_fn=tm.hubert_base, - data_gen_fn=wav2vec_data_gen_fn, - output_transform_fn=transformer_output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="torchaudio_hubert_base", + model_fn=tm.hubert_base, + data_gen_fn=wav2vec_data_gen_fn, + output_transform_fn=transformer_output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/torchrec/torchrec.py b/tests/kit/model_zoo/torchrec/torchrec.py index dda563155fca..d4baf576d54b 100644 --- a/tests/kit/model_zoo/torchrec/torchrec.py +++ b/tests/kit/model_zoo/torchrec/torchrec.py @@ -1,4 +1,3 @@ -from collections import namedtuple from functools import partial import torch @@ -7,7 +6,7 @@ from torchrec.modules.embedding_modules import EmbeddingBagCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor -from ..registry import ModelAttribute, model_zoo +from ..registry import model_zoo BATCH = 2 SHAPE = 10 @@ -20,9 +19,9 @@ def gen_kt(): # KeyedJaggedTensor def gen_kjt(): - KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"], - values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), - offsets=torch.tensor([0, 2, 4, 6, 8])) + KJT = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2"], values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), offsets=torch.tensor([0, 2, 4, 6, 8]) + ) return KJT @@ -68,7 +67,7 @@ def get_ebc(): # EmbeddingBagCollection eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]) eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]) - return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device('cpu')) + return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device("cpu")) def sparse_arch_model_fn(): @@ -91,52 +90,69 @@ def dlrm_sparsearch_model_fn(): return dlrm.SparseArch(ebc) -model_zoo.register(name='deepfm_densearch', - model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_interactionarch', - model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE), - data_gen_fn=interaction_arch_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_overarch', - model_fn=partial(deepfm.OverArch, SHAPE), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_simpledeepfmnn', - model_fn=simple_deep_fmnn_model_fn, - data_gen_fn=simple_dfm_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='deepfm_sparsearch', - model_fn=sparse_arch_model_fn, - data_gen_fn=sparse_arch_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm', - model_fn=dlrm_model_fn, - data_gen_fn=simple_dfm_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_densearch', - model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_interactionarch', - model_fn=partial(dlrm.InteractionArch, 2), - data_gen_fn=interaction_arch_data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_overarch', - model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - -model_zoo.register(name='dlrm_sparsearch', - model_fn=dlrm_sparsearch_model_fn, - data_gen_fn=sparse_arch_data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="deepfm_densearch", + model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_interactionarch", + model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE), + data_gen_fn=interaction_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_overarch", + model_fn=partial(deepfm.OverArch, SHAPE), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_simpledeepfmnn", + model_fn=simple_deep_fmnn_model_fn, + data_gen_fn=simple_dfm_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="deepfm_sparsearch", + model_fn=sparse_arch_model_fn, + data_gen_fn=sparse_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm", model_fn=dlrm_model_fn, data_gen_fn=simple_dfm_data_gen_fn, output_transform_fn=output_transform_fn +) + +model_zoo.register( + name="dlrm_densearch", + model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm_interactionarch", + model_fn=partial(dlrm.InteractionArch, 2), + data_gen_fn=interaction_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm_overarch", + model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) + +model_zoo.register( + name="dlrm_sparsearch", + model_fn=dlrm_sparsearch_model_fn, + data_gen_fn=sparse_arch_data_gen_fn, + output_transform_fn=output_transform_fn, +) diff --git a/tests/kit/model_zoo/torchvision/torchvision.py b/tests/kit/model_zoo/torchvision/torchvision.py index ddc3ec24b2ff..57b633e9d676 100644 --- a/tests/kit/model_zoo/torchvision/torchvision.py +++ b/tests/kit/model_zoo/torchvision/torchvision.py @@ -1,5 +1,3 @@ -from collections import namedtuple - import torch import torchvision import torchvision.models as tm @@ -29,103 +27,133 @@ def swin_s(): depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=[7, 7], - stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic + stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic weights=weights, progress=progress, ) # special output transform fn -google_net_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs - ) else dict(output=x) -swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val - for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) -inception_v3_output_transform_fn = lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs - ) else dict(output=x) +google_net_output_transform_fn = ( + lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.GoogLeNetOutputs) else dict(output=x) +) +swin_s_output_output_transform_fn = ( + lambda x: {f"output{idx}": val for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x) +) +inception_v3_output_transform_fn = ( + lambda x: dict(output=sum(x)) if isinstance(x, torchvision.models.InceptionOutputs) else dict(output=x) +) -model_zoo.register(name='torchvision_alexnet', - model_fn=tm.alexnet, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_densenet121', - model_fn=tm.densenet121, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_efficientnet_b0', - model_fn=tm.efficientnet_b0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) -model_zoo.register(name='torchvision_googlenet', - model_fn=tm.googlenet, - data_gen_fn=data_gen_fn, - output_transform_fn=google_net_output_transform_fn) -model_zoo.register(name='torchvision_inception_v3', - model_fn=tm.inception_v3, - data_gen_fn=inception_v3_data_gen_fn, - output_transform_fn=inception_v3_output_transform_fn) -model_zoo.register(name='torchvision_mobilenet_v2', - model_fn=tm.mobilenet_v2, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_mobilenet_v3_small', - model_fn=tm.mobilenet_v3_small, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_mnasnet0_5', - model_fn=tm.mnasnet0_5, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_resnet18', - model_fn=tm.resnet18, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_regnet_x_16gf', - model_fn=tm.regnet_x_16gf, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_resnext50_32x4d', - model_fn=tm.resnext50_32x4d, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_shufflenet_v2_x0_5', - model_fn=tm.shufflenet_v2_x0_5, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_squeezenet1_0', - model_fn=tm.squeezenet1_0, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="torchvision_alexnet", model_fn=tm.alexnet, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="torchvision_densenet121", + model_fn=tm.densenet121, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_efficientnet_b0", + model_fn=tm.efficientnet_b0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True), +) +model_zoo.register( + name="torchvision_googlenet", + model_fn=tm.googlenet, + data_gen_fn=data_gen_fn, + output_transform_fn=google_net_output_transform_fn, +) +model_zoo.register( + name="torchvision_inception_v3", + model_fn=tm.inception_v3, + data_gen_fn=inception_v3_data_gen_fn, + output_transform_fn=inception_v3_output_transform_fn, +) +model_zoo.register( + name="torchvision_mobilenet_v2", + model_fn=tm.mobilenet_v2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_mobilenet_v3_small", + model_fn=tm.mobilenet_v3_small, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_mnasnet0_5", + model_fn=tm.mnasnet0_5, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_resnet18", model_fn=tm.resnet18, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="torchvision_regnet_x_16gf", + model_fn=tm.regnet_x_16gf, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_resnext50_32x4d", + model_fn=tm.resnext50_32x4d, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_shufflenet_v2_x0_5", + model_fn=tm.shufflenet_v2_x0_5, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) +model_zoo.register( + name="torchvision_squeezenet1_0", + model_fn=tm.squeezenet1_0, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) -model_zoo.register(name='torchvision_vgg11', - model_fn=tm.vgg11, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) -model_zoo.register(name='torchvision_wide_resnet50_2', - model_fn=tm.wide_resnet50_2, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) +model_zoo.register( + name="torchvision_vgg11", model_fn=tm.vgg11, data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn +) +model_zoo.register( + name="torchvision_wide_resnet50_2", + model_fn=tm.wide_resnet50_2, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, +) -if version.parse(torchvision.__version__) >= version.parse('0.12.0'): - model_zoo.register(name='torchvision_vit_b_16', - model_fn=tm.vit_b_16, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn) - model_zoo.register(name='torchvision_convnext_base', - model_fn=tm.convnext_base, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) +if version.parse(torchvision.__version__) >= version.parse("0.12.0"): + model_zoo.register( + name="torchvision_vit_b_16", + model_fn=tm.vit_b_16, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + ) + model_zoo.register( + name="torchvision_convnext_base", + model_fn=tm.convnext_base, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True), + ) -if version.parse(torchvision.__version__) >= version.parse('0.13.0'): +if version.parse(torchvision.__version__) >= version.parse("0.13.0"): model_zoo.register( - name='torchvision_swin_s', + name="torchvision_swin_s", model_fn=swin_s, data_gen_fn=data_gen_fn, output_transform_fn=swin_s_output_output_transform_fn, ) - model_zoo.register(name='torchvision_efficientnet_v2_s', - model_fn=tm.efficientnet_v2_s, - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_stochastic_depth_prob=True)) + model_zoo.register( + name="torchvision_efficientnet_v2_s", + model_fn=tm.efficientnet_v2_s, + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_stochastic_depth_prob=True), + ) diff --git a/tests/kit/model_zoo/transformers/albert.py b/tests/kit/model_zoo/transformers/albert.py index 70f9ee11ad6e..d1c23703b3e4 100644 --- a/tests/kit/model_zoo/transformers/albert.py +++ b/tests/kit/model_zoo/transformers/albert.py @@ -19,44 +19,52 @@ def data_gen_fn(): def data_gen_for_pretrain(): inputs = data_gen_fn() - inputs['labels'] = inputs['input_ids'].clone() - inputs['sentence_order_label'] = torch.zeros(BATCH_SIZE, dtype=torch.int64) + inputs["labels"] = inputs["input_ids"].clone() + inputs["sentence_order_label"] = torch.zeros(BATCH_SIZE, dtype=torch.int64) return inputs output_transform_fn = lambda x: x -config = transformers.AlbertConfig(embedding_size=128, - hidden_size=128, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=256) - -model_zoo.register(name='transformers_albert', - model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_pretraining', - model_fn=lambda: transformers.AlbertForPreTraining(config), - data_gen_fn=data_gen_for_pretrain, - output_transform_fn=lambda x: dict(loss=x.loss), - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_masked_lm', - model_fn=lambda: transformers.AlbertForMaskedLM(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_sequence_classification', - model_fn=lambda: transformers.AlbertForSequenceClassification(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_token_classification', - model_fn=lambda: transformers.AlbertForTokenClassification(config), - data_gen_fn=data_gen_fn, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +config = transformers.AlbertConfig( + embedding_size=128, hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256 +) + +model_zoo.register( + name="transformers_albert", + model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_pretraining", + model_fn=lambda: transformers.AlbertForPreTraining(config), + data_gen_fn=data_gen_for_pretrain, + output_transform_fn=lambda x: dict(loss=x.loss), + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_masked_lm", + model_fn=lambda: transformers.AlbertForMaskedLM(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_sequence_classification", + model_fn=lambda: transformers.AlbertForSequenceClassification(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_token_classification", + model_fn=lambda: transformers.AlbertForTokenClassification(config), + data_gen_fn=data_gen_fn, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) # =============================== # Register multi-sentence ALBERT @@ -80,13 +88,17 @@ def data_gen_for_mcq(): return encoding -model_zoo.register(name='transformers_albert_for_question_answering', - model_fn=lambda: transformers.AlbertForQuestionAnswering(config), - data_gen_fn=data_gen_for_qa, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_albert_for_multiple_choice', - model_fn=lambda: transformers.AlbertForMultipleChoice(config), - data_gen_fn=data_gen_for_mcq, - output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_albert_for_question_answering", + model_fn=lambda: transformers.AlbertForQuestionAnswering(config), + data_gen_fn=data_gen_for_qa, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_albert_for_multiple_choice", + model_fn=lambda: transformers.AlbertForMultipleChoice(config), + data_gen_fn=data_gen_for_mcq, + output_transform_fn=output_transform_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 993c90b0abc2..8b90a3c7372c 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -28,7 +28,7 @@ def data_gen_for_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - data['labels'] = data['input_ids'].clone() + data["labels"] = data["input_ids"].clone() return data @@ -36,7 +36,7 @@ def data_gen_for_pretraining(): # pretraining data gen # `next_sentence_label` is the label for next sentence prediction, 0 or 1 data = data_gen_for_lm() - data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64) + data["next_sentence_label"] = torch.tensor([1], dtype=torch.int64) return data @@ -44,7 +44,7 @@ def data_gen_for_sequence_classification(): # sequence classification data gen # `labels` is the label for sequence classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([1], dtype=torch.int64) + data["labels"] = torch.tensor([1], dtype=torch.int64) return data @@ -52,7 +52,7 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) return data @@ -67,32 +67,276 @@ def data_gen_for_mcq(): # data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) # data = {k: v.unsqueeze(0) for k, v in encoding.items()} # data['labels'] = torch.tensor([0], dtype=torch.int64) - input_ids = torch.tensor([[[ - 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, - 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102, 5442, - 1012, 102, 102 - ], - [ - 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, - 4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096, - 2218, 1999, 1996, 2192, 1012, 102, 0, 0, 1012, 102, 0, 0 - ]]]) - token_type_ids = torch.tensor([[[ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1 - ], - [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0 - ]]]) - attention_mask = torch.tensor([[[ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1 - ], - [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0 - ]]]) + input_ids = torch.tensor( + [ + [ + [ + 101, + 1999, + 3304, + 1010, + 10733, + 2366, + 1999, + 5337, + 10906, + 1010, + 2107, + 2004, + 2012, + 1037, + 4825, + 1010, + 2003, + 3591, + 4895, + 14540, + 6610, + 2094, + 1012, + 102, + 2009, + 2003, + 8828, + 2007, + 1037, + 9292, + 1998, + 1037, + 5442, + 1012, + 102, + 102, + 5442, + 1012, + 102, + 102, + ], + [ + 101, + 1999, + 3304, + 1010, + 10733, + 2366, + 1999, + 5337, + 10906, + 1010, + 2107, + 2004, + 2012, + 1037, + 4825, + 1010, + 2003, + 3591, + 4895, + 14540, + 6610, + 2094, + 1012, + 102, + 2009, + 2003, + 8828, + 2096, + 2218, + 1999, + 1996, + 2192, + 1012, + 102, + 0, + 0, + 1012, + 102, + 0, + 0, + ], + ] + ] + ) + token_type_ids = torch.tensor( + [ + [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 1, + 1, + 0, + 0, + ], + ] + ] + ) + attention_mask = torch.tensor( + [ + [ + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 1, + 1, + 0, + 0, + ], + ] + ] + ) labels = torch.tensor([0], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) @@ -103,9 +347,9 @@ def data_gen_for_qa(): # no need for labels and use start and end position instead data = data_gen() start_positions = torch.tensor([0], dtype=torch.int64) - data['start_positions'] = start_positions + data["start_positions"] = start_positions end_positions = torch.tensor([1], dtype=torch.int64) - data['end_positions'] = end_positions + data["end_positions"] = end_positions return data @@ -114,69 +358,90 @@ def data_gen_for_qa(): # define loss funciton -loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state - )) +loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) loss_fn = lambda x: x.loss -config = transformers.BertConfig(hidden_size=128, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=256, - hidden_dropout_prob=0, - attention_probs_dropout_prob=0) +config = transformers.BertConfig( + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256, + hidden_dropout_prob=0, + attention_probs_dropout_prob=0, +) # register the BERT variants -model_zoo.register(name='transformers_bert', - model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_bert_model, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_pretraining', - model_fn=lambda: transformers.BertForPreTraining(config), - data_gen_fn=data_gen_for_pretraining, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_lm_head_model', - model_fn=lambda: transformers.BertLMHeadModel(config), - data_gen_fn=data_gen_for_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_masked_lm', - model_fn=lambda: transformers.BertForMaskedLM(config), - data_gen_fn=data_gen_for_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_sequence_classification', - model_fn=lambda: transformers.BertForSequenceClassification(config), - data_gen_fn=data_gen_for_sequence_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_token_classification', - model_fn=lambda: transformers.BertForTokenClassification(config), - data_gen_fn=data_gen_for_token_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_next_sentence', - model_fn=lambda: transformers.BertForNextSentencePrediction(config), - data_gen_fn=data_gen_for_sequence_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_mcq', - model_fn=lambda: transformers.BertForMultipleChoice(config), - data_gen_fn=data_gen_for_mcq, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bert_for_question_answering', - model_fn=lambda: transformers.BertForQuestionAnswering(config), - data_gen_fn=data_gen_for_qa, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_bert", + model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bert_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_pretraining", + model_fn=lambda: transformers.BertForPreTraining(config), + data_gen_fn=data_gen_for_pretraining, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_lm_head_model", + model_fn=lambda: transformers.BertLMHeadModel(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_masked_lm", + model_fn=lambda: transformers.BertForMaskedLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_sequence_classification", + model_fn=lambda: transformers.BertForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_token_classification", + model_fn=lambda: transformers.BertForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_next_sentence", + model_fn=lambda: transformers.BertForNextSentencePrediction(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_mcq", + model_fn=lambda: transformers.BertForMultipleChoice(config), + data_gen_fn=data_gen_for_mcq, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bert_for_question_answering", + model_fn=lambda: transformers.BertForQuestionAnswering(config), + data_gen_fn=data_gen_for_qa, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/blip2.py b/tests/kit/model_zoo/transformers/blip2.py index 984a6ffa920d..887b11c7f54e 100644 --- a/tests/kit/model_zoo/transformers/blip2.py +++ b/tests/kit/model_zoo/transformers/blip2.py @@ -47,16 +47,20 @@ def data_gen(): config.text_config.dropout = 0 # register the blip2 variants -model_zoo.register(name='transformers_blip2', - model_fn=lambda: transformers.Blip2Model(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_blip2_model, - model_attribute=ModelAttribute(has_control_flow=True)) - -model_zoo.register(name='transformers_blip2_conditional_gerneration', - model_fn=lambda: transformers.Blip2ForConditionalGeneration(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_blip2_model, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_blip2", + model_fn=lambda: transformers.Blip2Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_blip2_model, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_blip2_conditional_gerneration", + model_fn=lambda: transformers.Blip2ForConditionalGeneration(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_blip2_model, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py index 2d9c882089cb..12dcd71d5d1b 100644 --- a/tests/kit/model_zoo/transformers/bloom.py +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -25,7 +25,7 @@ def data_gen_for_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - data['labels'] = data['input_ids'].clone() + data["labels"] = data["input_ids"].clone() return data @@ -33,14 +33,14 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) return data def data_gen_for_sequence_classification(): # sequence classification data gen data = data_gen() - data['labels'] = torch.tensor([0], dtype=torch.int64) + data["labels"] = torch.tensor([0], dtype=torch.int64) return data @@ -54,62 +54,69 @@ def data_gen_for_question_answering(): input_ids = torch.tensor( [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], - dtype=torch.int64) + dtype=torch.int64, + ) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) start_positions = torch.tensor([1], dtype=torch.int64) end_positions = torch.tensor([10], dtype=torch.int64) - return dict(input_ids=input_ids, - attention_mask=attention_mask, - start_positions=start_positions, - end_positions=end_positions) + return dict( + input_ids=input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions + ) # define output transform function output_transform_fn = lambda x: x # define loss function -loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, - torch.ones_like(x.last_hidden_state)) +loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) loss_fn_for_causal_lm = lambda x: x.loss loss_fn_for_classification = lambda x: x.loss loss_fn_for_question_answering = lambda x: x.loss -config = transformers.BloomConfig(n_layer=2, - n_head=4, - vocab_size=250880, - hidden_dropout=0, - attention_dropout=0, - hidden_size=64, - pad_token_id=50256) +config = transformers.BloomConfig( + n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256 +) # register the following models -model_zoo.register(name='transformers_bloom', - model_fn=lambda: transformers.BloomModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_bloom_model, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bloom_for_causal_lm', - model_fn=lambda: transformers.BloomForCausalLM(config), - data_gen_fn=data_gen_for_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_causal_lm, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bloom_for_sequence_classification', - model_fn=lambda: transformers.BloomForSequenceClassification(config), - data_gen_fn=data_gen_for_sequence_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_classification, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bloom_for_token_classification', - model_fn=lambda: transformers.BloomForTokenClassification(config), - data_gen_fn=data_gen_for_token_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_classification, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_bloom_for_question_answering', - model_fn=lambda: transformers.BloomForQuestionAnswering(config), - data_gen_fn=data_gen_for_question_answering, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_question_answering, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_bloom", + model_fn=lambda: transformers.BloomModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bloom_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_causal_lm", + model_fn=lambda: transformers.BloomForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_causal_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_sequence_classification", + model_fn=lambda: transformers.BloomForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_token_classification", + model_fn=lambda: transformers.BloomForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_bloom_for_question_answering", + model_fn=lambda: transformers.BloomForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_question_answering, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index d543df00bdfa..22885bec224a 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -1,5 +1,4 @@ import torch -import transformers from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel @@ -21,8 +20,8 @@ def data_gen_for_conditional_generation(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - labels = data['input_ids'].clone() - data['labels'] = labels + labels = data["input_ids"].clone() + data["labels"] = labels return data @@ -30,29 +29,36 @@ def data_gen_for_conditional_generation(): output_transform_fn = lambda x: x # define loss function -loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, - torch.ones_like(x.last_hidden_state)) +loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) loss_fn = lambda x: x.loss -config = ChatGLMConfig(num_layers=2, - padded_vocab_size=65024, - hidden_size=64, - num_attention_heads=8, - rmsnorm=True, - original_rope=True, - use_cache=True, - torch_dtype=torch.float32) - -model_zoo.register(name='transformers_chatglm', - model_fn=lambda: ChatGLMModel(config, empty_init=False), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_chatglm_model, - model_attribute=ModelAttribute(has_control_flow=True)) - -model_zoo.register(name="transformers_chatglm_for_conditional_generation", - model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False), - data_gen_fn=data_gen_for_conditional_generation, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +config = ChatGLMConfig( + num_layers=2, + padded_vocab_size=65024, + hidden_size=64, + num_attention_heads=8, + rmsnorm=True, + original_rope=True, + use_cache=True, + torch_dtype=torch.float32, +) + +model_zoo.register( + name="transformers_chatglm", + model_fn=lambda: ChatGLMModel(config, empty_init=False), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_chatglm_model, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_chatglm_for_conditional_generation", + model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 0198e04689ea..2af6176fbe4a 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -27,7 +27,7 @@ def data_gen_for_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - data['labels'] = data['input_ids'].clone() + data["labels"] = data["input_ids"].clone() return data @@ -36,9 +36,9 @@ def data_gen_for_question_answering(): # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() start_positions = torch.tensor([0], dtype=torch.int64) - data['start_positions'] = start_positions + data["start_positions"] = start_positions end_positions = torch.tensor([1], dtype=torch.int64) - data['end_positions'] = end_positions + data["end_positions"] = end_positions return data @@ -46,14 +46,14 @@ def data_gen_for_token_classification(): # token classification data gen # `labels` is the type not the token id for token classification, 0 or 1 data = data_gen() - data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64) + data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64) return data def data_gen_for_sequence_classification(): # sequence classification data gen data = data_gen() - data['labels'] = torch.tensor([1], dtype=torch.int64) + data["labels"] = torch.tensor([1], dtype=torch.int64) return data @@ -62,7 +62,8 @@ def date_gen_for_double_heads(): batch_size = 2 input_ids = torch.tensor( [[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]], - dtype=torch.int64) + dtype=torch.int64, + ) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64) @@ -85,58 +86,73 @@ def date_gen_for_double_heads(): output_transform_fn = lambda x: x # define loss function -loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state - )) +loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) loss_fn = lambda x: x.loss -config = transformers.GPT2Config(n_layer=2, - n_head=4, - vocab_size=50258, - attn_pdrop=0, - embd_pdrop=0, - resid_pdrop=0, - summary_first_dropout=0, - hidden_dropout=0, - problem_type="single_label_classification", - pad_token_id=50256) +config = transformers.GPT2Config( + n_layer=2, + n_head=4, + vocab_size=50258, + attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + summary_first_dropout=0, + hidden_dropout=0, + problem_type="single_label_classification", + pad_token_id=50256, +) config_for_token_classification = copy.deepcopy(config) config_for_token_classification.num_labels = 2 # register the following models -model_zoo.register(name='transformers_gpt', - model_fn=lambda: transformers.GPT2Model(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_gpt2_model, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_lm', - model_fn=lambda: transformers.GPT2LMHeadModel(config), - data_gen_fn=data_gen_for_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_double_heads', - model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), - data_gen_fn=date_gen_for_double_heads, - output_transform_fn=output_transform_fn, - loss_fn=lambda x: x.loss + x.mc_loss, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_for_question_answering', - model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), - data_gen_fn=data_gen_for_question_answering, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_for_token_classification', - model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), - data_gen_fn=data_gen_for_token_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_gpt_for_sequence_classification', - model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), - data_gen_fn=data_gen_for_sequence_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_gpt", + model_fn=lambda: transformers.GPT2Model(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_gpt2_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_lm", + model_fn=lambda: transformers.GPT2LMHeadModel(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_double_heads", + model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), + data_gen_fn=date_gen_for_double_heads, + output_transform_fn=output_transform_fn, + loss_fn=lambda x: x.loss + x.mc_loss, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_question_answering", + model_fn=lambda: transformers.GPT2ForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_token_classification", + model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_gpt_for_sequence_classification", + model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 2018f3b4f440..bc229b17e08c 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -4,7 +4,8 @@ from ..registry import ModelAttribute, model_zoo try: - from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel + from transformers import LlamaConfig + HAS_LLAMA = True except ImportError: HAS_LLAMA = False @@ -33,8 +34,8 @@ def data_gen(): # label is needed for casual lm def data_gen_for_casual_lm(): data = data_gen() - labels = data['input_ids'].clone() - data['labels'] = labels + labels = data["input_ids"].clone() + data["labels"] = labels return data # transform the output to a dict @@ -45,12 +46,14 @@ def data_gen_for_casual_lm(): loss_fn_for_casual_lm = lambda output: output.loss loss_fn_for_seq_classification = lambda output: output.logits.mean() - config = LlamaConfig(num_hidden_layers=4, - hidden_size=128, - intermediate_size=256, - num_attention_heads=4, - max_position_embeddings=128, - num_labels=16) + config = LlamaConfig( + num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128, + num_labels=16, + ) if hasattr(config, "pad_token_id"): config.pad_token_id = config.eos_token_id @@ -59,21 +62,27 @@ def data_gen_for_casual_lm(): # transformers.LlamaModel, # transformers.LlamaForCausalLM, # transformers.LlamaForSequenceClassification, - model_zoo.register(name='transformers_llama', - model_fn=lambda: transformers.LlamaModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) - model_zoo.register(name='transformers_llama_for_casual_lm', - model_fn=lambda: transformers.LlamaForCausalLM(config), - data_gen_fn=data_gen_for_casual_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_casual_lm, - model_attribute=ModelAttribute(has_control_flow=True)) - model_zoo.register(name='transformers_llama_for_sequence_classification', - model_fn=lambda: transformers.LlamaForSequenceClassification(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_seq_classification, - model_attribute=ModelAttribute(has_control_flow=True)) + model_zoo.register( + name="transformers_llama", + model_fn=lambda: transformers.LlamaModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_llama_for_casual_lm", + model_fn=lambda: transformers.LlamaForCausalLM(config), + data_gen_fn=data_gen_for_casual_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_casual_lm, + model_attribute=ModelAttribute(has_control_flow=True), + ) + model_zoo.register( + name="transformers_llama_for_sequence_classification", + model_fn=lambda: transformers.LlamaForSequenceClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_seq_classification, + model_attribute=ModelAttribute(has_control_flow=True), + ) diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index a258e12ac127..07ca41ef21ae 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -20,8 +20,8 @@ def data_gen_for_causal_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - labels = data['input_ids'].clone() - data['labels'] = labels + labels = data["input_ids"].clone() + data["labels"] = labels return data @@ -29,8 +29,8 @@ def data_gen_for_sequence_classification(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - labels = data['input_ids'].clone() - data['labels'] = torch.tensor([1]) + data["input_ids"].clone() + data["labels"] = torch.tensor([1]) return data @@ -38,14 +38,15 @@ def data_gen_for_question_answering(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` data = data_gen() - data['start_positions'] = torch.tensor([0]) - data['end_positions'] = torch.tensor([1]) + data["start_positions"] = torch.tensor([0]) + data["end_positions"] = torch.tensor([1]) return data output_transform_fn = lambda x: x -loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state) - ) +loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss( + x.last_hidden_state, torch.ones_like(x.last_hidden_state) +) loss_fn_for_lm = lambda x: x.loss config = transformers.OPTConfig( hidden_size=128, @@ -57,24 +58,30 @@ def data_gen_for_question_answering(): # register the following models # transformers.OPTModel, # transformers.OPTForCausalLM, -model_zoo.register(name='transformers_opt', - model_fn=lambda: transformers.OPTModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_opt_model, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_opt_for_causal_lm', - model_fn=lambda: transformers.OPTForCausalLM(config), - data_gen_fn=data_gen_for_causal_lm, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_lm, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_opt_for_question_answering', - model_fn=lambda: transformers.OPTForQuestionAnswering(config), - data_gen_fn=data_gen_for_question_answering, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_lm, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_opt", + model_fn=lambda: transformers.OPTModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_opt_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_opt_for_causal_lm", + model_fn=lambda: transformers.OPTForCausalLM(config), + data_gen_fn=data_gen_for_causal_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_opt_for_question_answering", + model_fn=lambda: transformers.OPTForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True), +) # TODO The loss and gradient check in the test are failing, to be fixed. # model_zoo.register(name='transformers_opt_for_sequence_classification', diff --git a/tests/kit/model_zoo/transformers/sam.py b/tests/kit/model_zoo/transformers/sam.py index d850623f368f..b928a8f14e75 100644 --- a/tests/kit/model_zoo/transformers/sam.py +++ b/tests/kit/model_zoo/transformers/sam.py @@ -28,10 +28,12 @@ def data_gen(): original_sizes = torch.tensor([[1764, 2646]], dtype=torch.int64) reshaped_input_sizes = torch.tensor([[683, 1024]], dtype=torch.int64) input_points = torch.tensor([[[[174.1497, 232.3129]]]], dtype=torch.float64) - return dict(pixel_values=pixel_values, - original_sizes=original_sizes, - reshaped_input_sizes=reshaped_input_sizes, - input_points=input_points) + return dict( + pixel_values=pixel_values, + original_sizes=original_sizes, + reshaped_input_sizes=reshaped_input_sizes, + input_points=input_points, + ) # define output transform function @@ -44,9 +46,11 @@ def data_gen(): config.vision_config.num_hidden_layers = 2 # register the BERT variants -model_zoo.register(name='transformers_sam', - model_fn=lambda: transformers.SamModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_sam", + model_fn=lambda: transformers.SamModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 16a594f3950a..1b63cccc42ee 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -27,7 +27,7 @@ def data_gen_for_conditional_generation(): # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids data = data_gen_for_encoder_only() labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long() - data['labels'] = labels + data["labels"] = labels return data @@ -36,7 +36,7 @@ def data_gen_for_t5_model(): # decoder_input_ids = model._shift_right(input_ids) data = data_gen_for_encoder_only() decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long() - data['decoder_input_ids'] = decoder_input_ids + data["decoder_input_ids"] = decoder_input_ids return data @@ -55,21 +55,27 @@ def data_gen_for_t5_model(): # transformers.T5Model, # transformers.T5ForConditionalGeneration, # transformers.T5EncoderModel, -model_zoo.register(name='transformers_t5', - model_fn=lambda: transformers.T5Model(config), - data_gen_fn=data_gen_for_t5_model, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_t5_model, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_t5_for_conditional_generation', - model_fn=lambda: transformers.T5ForConditionalGeneration(config), - data_gen_fn=data_gen_for_conditional_generation, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_conditional_generation, - model_attribute=ModelAttribute(has_control_flow=True)) -model_zoo.register(name='transformers_t5_encoder_model', - model_fn=lambda: transformers.T5EncoderModel(config), - data_gen_fn=data_gen_for_encoder_only, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_encoder_only, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_t5", + model_fn=lambda: transformers.T5Model(config), + data_gen_fn=data_gen_for_t5_model, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_t5_model, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_t5_for_conditional_generation", + model_fn=lambda: transformers.T5ForConditionalGeneration(config), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_conditional_generation, + model_attribute=ModelAttribute(has_control_flow=True), +) +model_zoo.register( + name="transformers_t5_encoder_model", + model_fn=lambda: transformers.T5EncoderModel(config), + data_gen_fn=data_gen_for_encoder_only, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_encoder_only, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py index a84b8d31c284..f1990751b016 100644 --- a/tests/kit/model_zoo/transformers/vit.py +++ b/tests/kit/model_zoo/transformers/vit.py @@ -18,15 +18,15 @@ def data_gen(): def data_gen_for_image_classification(): data = data_gen() - data['labels'] = torch.tensor([0]) + data["labels"] = torch.tensor([0]) return data def data_gen_for_masked_image_modeling(): data = data_gen() - num_patches = (config.image_size // config.patch_size)**2 + num_patches = (config.image_size // config.patch_size) ** 2 bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() - data['bool_masked_pos'] = bool_masked_pos + data["bool_masked_pos"] = bool_masked_pos return data @@ -42,23 +42,29 @@ def data_gen_for_masked_image_modeling(): # transformers.ViTModel, # transformers.ViTForMaskedImageModeling, # transformers.ViTForImageClassification, -model_zoo.register(name='transformers_vit', - model_fn=lambda: transformers.ViTModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_vit_model, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_vit", + model_fn=lambda: transformers.ViTModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_vit_model, + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='transformers_vit_for_masked_image_modeling', - model_fn=lambda: transformers.ViTForMaskedImageModeling(config), - data_gen_fn=data_gen_for_masked_image_modeling, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_masked_image_modeling, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_vit_for_masked_image_modeling", + model_fn=lambda: transformers.ViTForMaskedImageModeling(config), + data_gen_fn=data_gen_for_masked_image_modeling, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_masked_image_modeling, + model_attribute=ModelAttribute(has_control_flow=True), +) -model_zoo.register(name='transformers_vit_for_image_classification', - model_fn=lambda: transformers.ViTForImageClassification(config), - data_gen_fn=data_gen_for_image_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_image_classification, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_vit_for_image_classification", + model_fn=lambda: transformers.ViTForImageClassification(config), + data_gen_fn=data_gen_for_image_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_image_classification, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py index f7cdc052aaf0..928be4468c01 100644 --- a/tests/kit/model_zoo/transformers/whisper.py +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -33,7 +33,7 @@ def data_gen_for_conditional_generation(): # or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is # only computed for the tokens with labels in `[0, ..., config.vocab_size]`. data = data_gen() - data['labels'] = torch.tensor([[0, 1]], dtype=torch.int64) + data["labels"] = torch.tensor([[0, 1]], dtype=torch.int64) return data @@ -44,8 +44,8 @@ def data_gen_for_audio_classification(): # `config.num_labels > 1` a classification loss is computed (Cross-Entropy). # `WhisperForAudioClassification` does not need `decoder_input_ids` data = data_gen() - data.pop('decoder_input_ids') - data['labels'] = torch.tensor([1], dtype=torch.int64) + data.pop("decoder_input_ids") + data["labels"] = torch.tensor([1], dtype=torch.int64) return data @@ -69,23 +69,29 @@ def data_gen_for_audio_classification(): ) # register the Whisper variants -model_zoo.register(name='transformers_whisper', - model_fn=lambda: transformers.WhisperModel(config), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn, - model_attribute=ModelAttribute(has_control_flow=True)) - -model_zoo.register(name='transformers_whisper_for_conditional_generation', - model_fn=lambda: transformers.WhisperForConditionalGeneration(config), - data_gen_fn=data_gen_for_conditional_generation, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_attr, - model_attribute=ModelAttribute(has_control_flow=True)) - -model_zoo.register(name='transformers_whisper_for_audio_classification', - model_fn=lambda: transformers.WhisperForAudioClassification(config), - data_gen_fn=data_gen_for_audio_classification, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_attr, - model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register( + name="transformers_whisper", + model_fn=lambda: transformers.WhisperModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_whisper_for_conditional_generation", + model_fn=lambda: transformers.WhisperForConditionalGeneration(config), + data_gen_fn=data_gen_for_conditional_generation, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_attr, + model_attribute=ModelAttribute(has_control_flow=True), +) + +model_zoo.register( + name="transformers_whisper_for_audio_classification", + model_fn=lambda: transformers.WhisperForAudioClassification(config), + data_gen_fn=data_gen_for_audio_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_attr, + model_attribute=ModelAttribute(has_control_flow=True), +) diff --git a/tests/test_analyzer/test_fx/test_bias_addition.py b/tests/test_analyzer/test_fx/test_bias_addition.py index f7b5eb140f24..f72c1cb3f533 100644 --- a/tests/test_analyzer/test_fx/test_bias_addition.py +++ b/tests/test_analyzer/test_fx/test_bias_addition.py @@ -12,7 +12,6 @@ class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features, bias): super().__init__() self.linear = torch.nn.Linear(in_features, out_features, bias=bias) @@ -23,25 +22,14 @@ def forward(self, x): class ConvModel(torch.nn.Module): - def __init__(self, in_channel, out_channels, kernel_size, bias) -> None: super().__init__() - self.conv = torch.nn.Conv2d(in_channel, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) - self.conv_transpose = torch.nn.ConvTranspose2d(in_channel, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) + self.conv = torch.nn.Conv2d( + in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) + self.conv_transpose = torch.nn.ConvTranspose2d( + in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) def forward(self, x, select=0): if select == 0: @@ -52,7 +40,6 @@ def forward(self, x, select=0): class SiuModel(torch.nn.Module): - def __init__(self, bias) -> None: super().__init__() self.linear = LinearModel(3, 3, bias) @@ -69,7 +56,6 @@ def forward(self, x, select=torch.Tensor([0])): class AddmmModel(torch.nn.Module): - def __init__(self, alpha, beta) -> None: super().__init__() self.alpha = alpha @@ -80,7 +66,7 @@ def forward(self, x): return x -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() @parameterize("bias", [True, False]) @parameterize("bias_addition_split", [True, False]) @@ -89,19 +75,21 @@ def forward(self, x): def test_siu_model(bias, bias_addition_split, shape, select): model = SiuModel(bias=bias) x = torch.rand(shape) - gm = symbolic_trace(model, - meta_args={'x': x}, - concrete_args={'select': select}, - trace_act_ckpt=True, - bias_addition_split=bias_addition_split) - assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!' + gm = symbolic_trace( + model, + meta_args={"x": x}, + concrete_args={"select": select}, + trace_act_ckpt=True, + bias_addition_split=bias_addition_split, + ) + assert torch.allclose(model(x, select), gm(x)), "original model and traced model should be the same!" if bias and bias_addition_split: - assert '+' in gm.code, 'bias addition should be split!' + assert "+" in gm.code, "bias addition should be split!" else: - assert '+' not in gm.code, 'bias addition should not be split!' + assert "+" not in gm.code, "bias addition should not be split!" -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @parameterize("alpha", [1, 2]) @parameterize("beta", [1, 2]) @parameterize("bias_addition_split", [True, False]) @@ -109,14 +97,14 @@ def test_siu_model(bias, bias_addition_split, shape, select): def test_addmm_model(alpha, beta, bias_addition_split, shape): model = AddmmModel(alpha=alpha, beta=beta) x = torch.rand(shape) - gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split) - assert torch.allclose(model(x), gm(x)), 'original model and traced model should be the same!' + gm = symbolic_trace(model, meta_args={"x": x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split) + assert torch.allclose(model(x), gm(x)), "original model and traced model should be the same!" if (alpha == 1 and beta == 1) or not bias_addition_split: - assert '*' not in gm.code, 'bias addition should not be split!' + assert "*" not in gm.code, "bias addition should not be split!" elif bias_addition_split: - assert '+' in gm.code, 'bias addition should be split!' + assert "+" in gm.code, "bias addition should be split!" -if __name__ == '__main__': +if __name__ == "__main__": test_siu_model() test_addmm_model() diff --git a/tests/test_analyzer/test_fx/test_mod_dir.py b/tests/test_analyzer/test_fx/test_mod_dir.py index f62147b297a2..be151b1edd80 100644 --- a/tests/test_analyzer/test_fx/test_mod_dir.py +++ b/tests/test_analyzer/test_fx/test_mod_dir.py @@ -10,7 +10,6 @@ class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features, bias): super().__init__() self.linear = torch.nn.Linear(in_features, out_features, bias=bias) @@ -21,25 +20,14 @@ def forward(self, x): class ConvModel(torch.nn.Module): - def __init__(self, in_channel, out_channels, kernel_size, bias) -> None: super().__init__() - self.conv = torch.nn.Conv2d(in_channel, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) - self.conv_transpose = torch.nn.ConvTranspose2d(out_channels, - out_channels, - kernel_size, - bias=bias, - padding=1, - stride=2, - dilation=2, - groups=3) + self.conv = torch.nn.Conv2d( + in_channel, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) + self.conv_transpose = torch.nn.ConvTranspose2d( + out_channels, out_channels, kernel_size, bias=bias, padding=1, stride=2, dilation=2, groups=3 + ) def forward(self, x): x = self.conv(x) @@ -48,7 +36,6 @@ def forward(self, x): class AModel(torch.nn.Module): - def __init__(self, bias) -> None: super().__init__() self.linear_1 = LinearModel(3, 3, bias) @@ -63,7 +50,7 @@ def forward(self, x): return x -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="torch version < 12") @clear_cache_before_run() @parameterize("bias", [True, False]) @parameterize("bias_addition_split", [True, False]) @@ -71,11 +58,11 @@ def forward(self, x): def test_mod_dir(bias, bias_addition_split, shape): model = AModel(bias=bias) x = torch.rand(shape) - gm = symbolic_trace(model, meta_args={'x': x}, bias_addition_split=bias_addition_split) + gm = symbolic_trace(model, meta_args={"x": x}, bias_addition_split=bias_addition_split) for node in gm.graph.nodes: - assert len(node.meta['info'].mod_dir), f"{node} should have non-trivial ``mod_dir``." - print(node, node.meta['info'].mod_dir) + assert len(node.meta["info"].mod_dir), f"{node} should have non-trivial ``mod_dir``." + print(node, node.meta["info"].mod_dir) -if __name__ == '__main__': +if __name__ == "__main__": test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3)) diff --git a/tests/test_analyzer/test_fx/test_nested_ckpt.py b/tests/test_analyzer/test_fx/test_nested_ckpt.py index bd16f5a4f95d..d7b96fb9f043 100644 --- a/tests/test_analyzer/test_fx/test_nested_ckpt.py +++ b/tests/test_analyzer/test_fx/test_nested_ckpt.py @@ -12,7 +12,6 @@ class MyModule(nn.Module): - def __init__(self): super().__init__() self.a = nn.Linear(10, 10) @@ -43,14 +42,14 @@ def forward(self, x): return checkpoint(self.checkpoint_0, x) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="torch version < 12") @clear_cache_before_run() def test_nested_ckpt(): model = MyModule() x = torch.rand(10, 10) - gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True) + gm = symbolic_trace(model, meta_args={"x": x}, trace_act_ckpt=True) assert torch.allclose(gm(x), model(x)), "The traced model should generate the same output as the original model." - for ckpt_def in filter(lambda s: s.startswith('checkpoint'), dir(model)): + for ckpt_def in filter(lambda s: s.startswith("checkpoint"), dir(model)): assert ckpt_def in gm.code, f"Checkpoint {ckpt_def} should be in the traced code.\n Traced code = {gm.code}" diff --git a/tests/test_analyzer/test_fx/test_shape_prop.py b/tests/test_analyzer/test_fx/test_shape_prop.py index a849feb795e5..609fc9c7b022 100644 --- a/tests/test_analyzer/test_fx/test_shape_prop.py +++ b/tests/test_analyzer/test_fx/test_shape_prop.py @@ -1,6 +1,5 @@ import pytest import torch -import torchvision.models as tm from packaging import version from colossalai.testing.utils import clear_cache_before_run, parameterize @@ -16,24 +15,25 @@ def linear_impl(*args, **kwargs): assert True return torch.nn.functional.linear(*args, **kwargs) + except: pass def _check_gm_validity(gm: torch.fx.GraphModule): for node in gm.graph.nodes: - assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.' + assert node.meta["info"].outputs, f"In {gm.__class__.__name__}, {node} has no output shape." if node.op in [ - 'call_module', # can apply to params - 'call_function', # can apply to params - 'call_method', # can apply to params + "call_module", # can apply to params + "call_function", # can apply to params + "call_method", # can apply to params ]: - assert hasattr(node.meta['info'], 'inputs'), f'In {gm.__class__.__name__}, {node} has no input shape.' + assert hasattr(node.meta["info"], "inputs"), f"In {gm.__class__.__name__}, {node} has no input shape." -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tm_models) +@parameterize("m", tm_models) def test_torchvision_shape_prop(m): with MetaTensorMode(): model = m() @@ -46,9 +46,9 @@ def test_torchvision_shape_prop(m): _check_gm_validity(gm) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tmm_models) +@parameterize("m", tmm_models) def test_timm_shape_prop(m): with MetaTensorMode(): model = m() diff --git a/tests/test_analyzer/test_fx/test_symbolic_profile.py b/tests/test_analyzer/test_fx/test_symbolic_profile.py index 17deee7a7118..8d8ee2445d58 100644 --- a/tests/test_analyzer/test_fx/test_symbolic_profile.py +++ b/tests/test_analyzer/test_fx/test_symbolic_profile.py @@ -1,6 +1,5 @@ import pytest import torch -import torchvision.models as tm from packaging import version from colossalai.testing.utils import clear_cache_before_run, parameterize @@ -15,12 +14,12 @@ def _check_gm_validity(gm: torch.fx.GraphModule): for node in gm.graph.nodes: - assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.' + assert len(node.meta["info"].global_ctx), f"In {gm.__class__.__name__}, {node} has empty global context." -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tm_models) +@parameterize("m", tm_models) def test_torchvision_profile(m, verbose=False, bias_addition_split=False): with MetaTensorMode(): model = m() @@ -33,9 +32,9 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False): _check_gm_validity(gm) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tmm_models) +@parameterize("m", tmm_models) def test_timm_profile(m, verbose=False, bias_addition_split=False): with MetaTensorMode(): model = m() diff --git a/tests/test_analyzer/test_subclasses/test_aten.py b/tests/test_analyzer/test_subclasses/test_aten.py index b7858110ac09..61c1d25f7b3d 100644 --- a/tests/test_analyzer/test_subclasses/test_aten.py +++ b/tests/test_analyzer/test_subclasses/test_aten.py @@ -14,35 +14,41 @@ aten = torch.ops.aten registered_meta = { - ('aten.convolution.default', True): [ # (aten ops, requires_backward) + ("aten.convolution.default", True): [ # (aten ops, requires_backward) (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)), (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)), (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), - (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4)), - (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4, 4)), + ( + nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4), + ), + ( + nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4, 4), + ), ], - ('aten.native_batch_norm.default', True): [ + ("aten.native_batch_norm.default", True): [ (nn.BatchNorm1d(4), torch.rand(2, 4)), (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)), (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)), ], - ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),], - ('aten.avg_pool1d.default', True): [ + ("aten.native_layer_norm.default", True): [ + (nn.LayerNorm(4), torch.rand(1, 2, 3, 4)), + ], + ("aten.avg_pool1d.default", True): [ (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)), (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)), ], - ('aten.avg_pool2d.default', True): [ + ("aten.avg_pool2d.default", True): [ (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)), ], - ('aten.relu.default', True): [ + ("aten.relu.default", True): [ (nn.ReLU(), torch.rand(4, 3, 1, 2)), (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)), (nn.SiLU(), torch.rand(4, 3, 1, 2)), @@ -51,15 +57,20 @@ (nn.Sigmoid(), torch.rand(4, 3, 1, 2)), (nn.Tanh(), torch.rand(4, 3, 1, 2)), (nn.Hardswish(), torch.rand(4, 3, 1, 2)), - ] + ], } def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any: - assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' - assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' - assert tensor.stride() == meta_tensor.stride( - ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + assert ( + tensor.shape == meta_tensor.shape + ), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match." + assert ( + tensor.dtype == meta_tensor.dtype + ), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match." + assert ( + tensor.stride() == meta_tensor.stride() + ), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match." def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any: @@ -73,7 +84,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="torch version < 12") @clear_cache_before_run() def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): @@ -81,5 +92,5 @@ def test_meta_aten(): run_and_compare(f, x, requires_backward) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_aten() diff --git a/tests/test_analyzer/test_subclasses/test_flop_tensor.py b/tests/test_analyzer/test_subclasses/test_flop_tensor.py index 4e9c9852649b..b1b9a89fad97 100644 --- a/tests/test_analyzer/test_subclasses/test_flop_tensor.py +++ b/tests/test_analyzer/test_subclasses/test_flop_tensor.py @@ -4,7 +4,6 @@ import torchvision.models as tm from packaging import version -from colossalai.testing import clear_cache_before_run, parameterize from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models try: @@ -13,40 +12,44 @@ pass -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') -@pytest.mark.parametrize('m', tm_models + tmm_models) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") +@pytest.mark.parametrize("m", tm_models + tmm_models) def test_flop_count_module(m): x = torch.rand(2, 3, 224, 224) - with MetaTensorMode(): # save time for testing + with MetaTensorMode(): # save time for testing module = m() rs_fwd, rs_bwd = flop_count(module, x, verbose=True) - assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}' - assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}' + assert rs_fwd > 0, f"fwd flop count of {m.__name__} is {rs_fwd}" + assert rs_bwd > 0, f"bwd flop count of {m.__name__} is {rs_bwd}" odd_cases = [ - (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), { - 'inplace': True - }), - (F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), { - 'kernel_size': 3, - 'stride': 2, - 'padding': 1, - 'dilation': 2 - }), - (torch.where, (torch.rand(2, 3, 224, 224) > 0.5, torch.rand(2, 3, 224, 224, requires_grad=True), - torch.rand(2, 3, 224, 224, requires_grad=True)), {}), + (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True}), + ( + F.max_pool2d, + (torch.rand(2, 3, 224, 224, requires_grad=True),), + {"kernel_size": 3, "stride": 2, "padding": 1, "dilation": 2}, + ), + ( + torch.where, + ( + torch.rand(2, 3, 224, 224) > 0.5, + torch.rand(2, 3, 224, 224, requires_grad=True), + torch.rand(2, 3, 224, 224, requires_grad=True), + ), + {}, + ), ] -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') -@pytest.mark.parametrize('func, args, kwargs', odd_cases) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") +@pytest.mark.parametrize("func, args, kwargs", odd_cases) def test_flop_count_function(func, args, kwargs): rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True) - assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}' - assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}' + assert rs_fwd > 0, f"fwd flop count of {func.__name__} is {rs_fwd}" + assert rs_bwd > 0, f"bwd flop count of {func.__name__} is {rs_bwd}" -if __name__ == '__main__': +if __name__ == "__main__": test_flop_count_module(tm.resnet18) - test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True}) + test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {"inplace": True}) diff --git a/tests/test_analyzer/test_subclasses/test_meta_mode.py b/tests/test_analyzer/test_subclasses/test_meta_mode.py index d2a0a1b9cfb5..c55c4ec42703 100644 --- a/tests/test_analyzer/test_subclasses/test_meta_mode.py +++ b/tests/test_analyzer/test_subclasses/test_meta_mode.py @@ -6,17 +6,22 @@ from colossalai.testing import clear_cache_before_run, parameterize try: - from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode + from colossalai._analyzer._subclasses import MetaTensorMode except: pass from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor): - assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' - assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' - assert tensor.stride() == meta_tensor.stride( - ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + assert ( + tensor.shape == meta_tensor.shape + ), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match." + assert ( + tensor.dtype == meta_tensor.dtype + ), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match." + assert ( + tensor.stride() == meta_tensor.stride() + ), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match." def run_and_compare(model): @@ -31,12 +36,12 @@ def run_and_compare(model): compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() -@parameterize('m', tm_models + tmm_models) +@parameterize("m", tm_models + tmm_models) def test_meta_mode_shape(m): run_and_compare(m()) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_mode_shape(tm.resnet18) diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py index b65e6d0d8863..03bba8e64772 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py @@ -8,6 +8,7 @@ import colossalai from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta + # from colossalai.fx.passes.algorithms import solver_rotor # from colossalai.fx.passes.algorithms.operation import Sequence from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -19,18 +20,18 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + withcodegen = True except: - from colossalai.fx.codegen import python_code_with_activation_checkpoint withcodegen = False def _run_C_solver_consistency_test(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]: model = M() - data = torch.rand(128, 3, 224, 224, device='meta') + data = torch.rand(128, 3, 224, 224, device="meta") tracer = ColoTracer() graph = tracer.trace(model, meta_args={"x": data}) @@ -54,15 +55,17 @@ def _run_C_solver_consistency_test(rank, world_size, port): for m in range(len(opt_python)): for d in range(1, len(opt_python[0])): for i in range(len(opt_python[0]) - d): - assert opt_python[m][i][i + d] == opt_C[m][i][i + d], \ - f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}" + assert ( + opt_python[m][i][i + d] == opt_C[m][i][i + d] + ), f"item ({m}, {i}, {i + d}) is not consistent with python version!\npython version: {opt_python[m][i][i + d]}\nC version: {opt_C[m][i][i + d]}" sequence_python = sequence_python.list_operations() sequence_C = sequence_C.list_operations() # make sure the sequences are the same - assert len(sequence_python) == len(sequence_C) and \ - all(python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C)) + assert len(sequence_python) == len(sequence_C) and all( + python_op.__repr__() == C_op.__repr__() for (python_op, C_op) in zip(sequence_python, sequence_C) + ) gpc.destroy() @@ -74,5 +77,5 @@ def test_C_solver_consistency(): spawn(_run_C_solver_consistency_test, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_C_solver_consistency_test(rank=0) diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py index babdddfada18..c46f57f75303 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py @@ -11,6 +11,7 @@ from colossalai.fx import ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.graph_module import ColoGraphModule + # from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.legacy.core import global_context as gpc @@ -21,10 +22,12 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False # SOLVERS = [chen_greedy, solver_rotor] @@ -33,7 +36,7 @@ def _is_activation_checkpoint_available(gm: GraphModule): for n in gm.graph.nodes: - if hasattr(n, 'activation_checkpoint') and getattr(n, 'activation_checkpoint') is not None: + if hasattr(n, "activation_checkpoint") and getattr(n, "activation_checkpoint") is not None: return True @@ -47,15 +50,19 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): def _is_graph_linearized(gm: GraphModule): code = gm.code # find patterns like r' return output_1, output_2', which is not expected on a linearized graph - pattern = re.compile(r' return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+') + pattern = re.compile(r" return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+") if pattern.findall(code): return False else: return True -def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule], - model_cls: Callable[[], torch.nn.Module]): +def check_backward_consistency( + m: torch.nn.Module, + gm: GraphModule, + solver: Callable[[GraphModule], GraphModule], + model_cls: Callable[[], torch.nn.Module], +): criterion = torch.nn.MSELoss() m.cuda() data = torch.rand(2, 3, 32, 32).cuda() @@ -64,18 +71,18 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call loss.backward() loss = criterion(gm(data), label) loss.backward() - assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' + assert _is_all_gradient_close(m, gm), f"Solver {solver} did not work correctly in backward pass on {model_cls}" def _run_ckpt_solver(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True tracer = ColoTracer(trace_act_ckpt=False) - data = torch.rand(8, 3, 224, 224, device='meta') + data = torch.rand(8, 3, 224, 224, device="meta") for solver in SOLVERS: for model_cls in MODEL_LIST: m = model_cls(num_classes=5) @@ -90,27 +97,28 @@ def _run_ckpt_solver(rank, world_size, port): gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( - gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + gm + ), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" check_backward_consistency(m, gm, solver, model_cls) gpc.destroy() @pytest.mark.skip("TODO(super-dainiu): refactor all tests.") -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @rerun_if_address_is_in_use() def test_ckpt_solver(): spawn(_run_ckpt_solver, 1) def _run_ckpt_solver_torch11(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True tracer = ColoTracer(trace_act_ckpt=False) - data = torch.rand(8, 3, 32, 32, device='meta') + data = torch.rand(8, 3, 32, 32, device="meta") for solver in SOLVERS: for model_cls in MODEL_LIST: m = model_cls(num_classes=5) @@ -124,19 +132,20 @@ def _run_ckpt_solver_torch11(rank, world_size, port): gm = solver(gm) assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner." assert _is_activation_checkpoint_available( - gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" + gm + ), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" check_backward_consistency(m, gm, solver, model_cls) gpc.destroy() -@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") @rerun_if_address_is_in_use() def test_ckpt_solver_torch11(): spawn(_run_ckpt_solver_torch11, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_ckpt_solver(rank=0) test_ckpt_solver() test_ckpt_solver_torch11() diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py index 59880815dc5e..bb3be9344566 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py @@ -5,6 +5,7 @@ from colossalai.fx import ColoTracer from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.graph_module import ColoGraphModule + # from colossalai.fx.passes.algorithms import linearize, solver_rotor # from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss) from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -15,14 +16,16 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False -@pytest.mark.skip(reason='TODO: modify the logger') +@pytest.mark.skip(reason="TODO: modify the logger") @pytest.mark.skip("TODO(lyl): refactor all tests.") @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @clear_cache_before_run() @@ -35,12 +38,12 @@ def test_linearize(): graph = tracer.trace(model) graph.set_codegen(ActivationCheckpointCodeGen()) gm = ColoGraphModule(model, graph, model.__class__.__name__) - MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device='cpu')) + MetaInfoProp(gm).run(MetaTensor(torch.rand(128, 3, 224, 224, device="meta"), fake_device="cpu")) node_list = linearize(gm) gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2) op_list = gm.__sequence__.list_operations() loss_op = next(op for op in op_list if isinstance(op, Loss)) - op_list = op_list[:op_list.index(loss_op)] + op_list = op_list[: op_list.index(loss_op)] in_ckpt = False ckpt_idx = 0 for idx, op in enumerate(op_list): @@ -48,8 +51,9 @@ def test_linearize(): if isinstance(op, ForwardNograd): for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint[ - 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + assert ( + n.activation_checkpoint[0] == ckpt_idx + ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" continue @@ -65,8 +69,9 @@ def test_linearize(): ckpt_idx += 1 for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint[ - 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + assert ( + n.activation_checkpoint[0] == ckpt_idx + ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" continue @@ -75,8 +80,9 @@ def test_linearize(): in_ckpt = True for n in node_list[idx]: assert hasattr(n, "activation_checkpoint"), f"{n} is not annotated!" - assert n.activation_checkpoint[ - 0] == ckpt_idx, f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" + assert ( + n.activation_checkpoint[0] == ckpt_idx + ), f"{n} ckpt_idx {n.activation_checkpoint[0]} wrong, should be {ckpt_idx}!" del model del gm @@ -100,7 +106,7 @@ def test_linearize_torch11(): gm = solver_rotor(gm, data=torch.rand(128, 3, 224, 224, device="meta"), mem_limit=budget * 1024**2) op_list = gm.__sequence__.list_operations() loss_op = next(op for op in op_list if isinstance(op, Loss)) - op_list = op_list[:op_list.index(loss_op)] + op_list = op_list[: op_list.index(loss_op)] in_ckpt = False ckpt_idx = 0 for idx, op in enumerate(op_list): diff --git a/tests/test_auto_parallel/test_offload/model_utils.py b/tests/test_auto_parallel/test_offload/model_utils.py index c22b17ae42ba..0efe84655aac 100644 --- a/tests/test_auto_parallel/test_offload/model_utils.py +++ b/tests/test_auto_parallel/test_offload/model_utils.py @@ -1,25 +1,23 @@ import torch import torch.nn as nn -from transformers import GPT2Config, GPT2LMHeadModel -from transformers import BertConfig, BertLMHeadModel +from transformers import BertConfig, BertLMHeadModel, GPT2Config, GPT2LMHeadModel + from tests.components_to_test.registry import non_distributed_component_funcs -class GPTLMModel(nn.Module): - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257): +class GPTLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257): super().__init__() self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) def forward(self, input_ids, attention_mask): # Only return lm_logits @@ -27,7 +25,6 @@ def forward(self, input_ids, attention_mask): class LMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() @@ -38,18 +35,27 @@ def forward(self, logits, labels): # Flatten the tokens return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + class BertLMModel(nn.Module): def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522): super().__init__() - self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size, - num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size, - vocab_size=vocab_size)) + self.model = BertLMHeadModel( + BertConfig( + n_embd=hidden_size, + num_hidden_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + max_position_embeddings=hidden_size, + vocab_size=vocab_size, + ) + ) def forward(self, input_ids, attention_mask): # Only return lm_logits return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] -@non_distributed_component_funcs.register(name='bert_') + +@non_distributed_component_funcs.register(name="bert_") def get_bert_components(): vocab_size = 1024 seq_len = 64 @@ -67,7 +73,8 @@ def bert_data_gen(device="meta"): return bert_model_builder, bert_data_gen -@non_distributed_component_funcs.register(name='gpt2_') + +@non_distributed_component_funcs.register(name="gpt2_") def get_gpt2_components(): vocab_size = 1024 seq_len = 8 @@ -83,4 +90,4 @@ def gpt2_data_gen(device="meta"): kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) return kwargs - return gpt2_model_builder, gpt2_data_gen \ No newline at end of file + return gpt2_model_builder, gpt2_data_gen diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index 45c22efc4127..2c8b260e6498 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -17,18 +17,22 @@ from tests.test_tensor.common_utils import set_seed -@parameterize('model_name', ['gpt2_']) -@parameterize('memory_budget', [5000]) -@parameterize('solver_name', ['asyn']) +@parameterize("model_name", ["gpt2_"]) +@parameterize("memory_budget", [5000]) +@parameterize("solver_name", ["asyn"]) def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): - # build model get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() - label = torch.randint(low=0, high=128, size=( - 64, - 8, - ), device=get_current_device()) + label = torch.randint( + low=0, + high=128, + size=( + 64, + 8, + ), + device=get_current_device(), + ) criterion = LMLoss() set_seed(42) @@ -50,17 +54,19 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3) optim = AMPOptimizer(hybrid_optimizer, model) - with ColoInitContext(device=torch.device('cpu')): + with ColoInitContext(device=torch.device("cpu")): gemini_model = model_builder() gemini_model.train() hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) - gemini_config = dict(strict_ddp_mode=False, - device=torch.device('cpu'), - placement_policy='cpu', - pin_memory=True, - hidden_dim=8192, - search_range_m=128) + gemini_config = dict( + strict_ddp_mode=False, + device=torch.device("cpu"), + placement_policy="cpu", + pin_memory=True, + hidden_dim=8192, + search_range_m=128, + ) gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config) optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config) @@ -89,9 +95,11 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): exec_time = sum(sorted(time_list)[:5]) / 5 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 - print(f'gemini | model_name: {model_name}') - print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(f"gemini | model_name: {model_name}") + print( + f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB " + f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|" + ) print(time_list) del data_args @@ -124,24 +132,26 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): exec_time = sum(sorted(time_list)[:5]) / 5 runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 - print(f'solver_name: {solver_name} | model_name: {model_name}') - print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') + print(f"solver_name: {solver_name} | model_name: {model_name}") + print( + f"| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB " + f"| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|" + ) print(time_list) def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_fwd_bwd() @pytest.mark.skip("this test failed") -@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed") @rerun_if_address_is_in_use() def test_perf(): spawn(run_dist, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_perf() diff --git a/tests/test_auto_parallel/test_offload/test_solver.py b/tests/test_auto_parallel/test_offload/test_solver.py index aa2c9a36849f..6bb53aa67495 100644 --- a/tests/test_auto_parallel/test_offload/test_solver.py +++ b/tests/test_auto_parallel/test_offload/test_solver.py @@ -11,13 +11,12 @@ from tests.test_auto_parallel.test_offload.model_utils import * -@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@pytest.mark.skipif(NOT_NVML, reason="pynvml is not installed") @clear_cache_before_run() -@parameterize('model_name', ['gpt2_', 'bert_']) -@parameterize('memory_budget', [4000]) -@parameterize('solver_name', ['syn', 'asyn']) +@parameterize("model_name", ["gpt2_", "bert_"]) +@parameterize("memory_budget", [4000]) +@parameterize("solver_name", ["syn", "asyn"]) def solver_test(model_name: str, memory_budget: float, solver_name: str): - get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() data_args = data_gen(device="cpu") @@ -53,15 +52,15 @@ def solver_test(model_name: str, memory_budget: float, solver_name: str): need_offload = region.need_offload to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None print( - f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + f"| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}" ) for region in region_list.__reversed__(): need_offload = region.need_offload to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None print( - f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + f"| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}" ) -if __name__ == '__main__': +if __name__ == "__main__": solver_test() diff --git a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py index 429e89aae5d3..2b89a73656b1 100644 --- a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass from colossalai.device.device_mesh import DeviceMesh @@ -10,7 +9,6 @@ class TestModule(torch.nn.Module): - def forward(self, x): x = x.view(4, 4, 2) return x @@ -19,7 +17,7 @@ def forward(self, x): def insert_narrow(gm, x_node): graph = gm.graph with graph.inserting_after(x_node): - shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + shard_node = graph.create_node("call_method", "narrow", args=(x_node, 0, 0, 2), kwargs={}) view_node = list(x_node.users.keys())[0] new_args = list(view_node.args) new_args[0] = shard_node @@ -33,7 +31,7 @@ def test_node_args_converting_pass(): physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - meta_args = {'x': torch.rand(4, 8).to('meta')} + meta_args = {"x": torch.rand(4, 8).to("meta")} input = torch.rand(4, 8) tracer = ColoTracer() graph = tracer.trace(root=model, meta_args=meta_args) @@ -41,8 +39,8 @@ def test_node_args_converting_pass(): x_node = list(graph.nodes)[0] view_node = list(graph.nodes)[1] sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) - setattr(x_node, 'sharding_spec', sharding_spec) - setattr(view_node, 'sharding_spec', sharding_spec) + setattr(x_node, "sharding_spec", sharding_spec) + setattr(view_node, "sharding_spec", sharding_spec) gm = ColoGraphModule(model, graph) gm = node_args_converting_pass(gm, device_mesh) @@ -52,5 +50,5 @@ def test_node_args_converting_pass(): assert output.shape == torch.Size([2, 4, 2]) -if __name__ == '__main__': +if __name__ == "__main__": test_node_args_converting_pass() diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py index bca81201c6ef..b6cc6c9b44fd 100644 --- a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py @@ -1,6 +1,5 @@ import pytest import torch -import torch.nn.functional as F from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes import shape_prop_pass @@ -12,7 +11,6 @@ class TestModule(torch.nn.Module): - def forward(self, x): size = x.size() return size @@ -21,7 +19,7 @@ def forward(self, x): def insert_narrow(gm, x_node): graph = gm.graph with graph.inserting_after(x_node): - shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={}) + shard_node = graph.create_node("call_method", "narrow", args=(x_node, 0, 0, 2), kwargs={}) size_node = list(x_node.users.keys())[0] size_node.args = (shard_node,) return gm @@ -36,20 +34,20 @@ def recover_narrow(gm, narrow_node): return gm -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") @clear_cache_before_run() def test_size_value_converting_pass(): model = TestModule() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - meta_args = {'x': torch.rand(4, 8).to('meta')} + meta_args = {"x": torch.rand(4, 8).to("meta")} input = torch.rand(4, 8) tracer = ColoTracer(bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_args) x_node = list(graph.nodes)[0] x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) - setattr(x_node, 'sharding_spec', x_sharding_spec) + setattr(x_node, "sharding_spec", x_sharding_spec) gm = ColoGraphModule(model, graph) gm = insert_narrow(gm, x_node) shape_prop_pass(gm, *meta_args.values()) @@ -66,5 +64,5 @@ def test_size_value_converting_pass(): assert size == torch.Size([4, 8]) -if __name__ == '__main__': +if __name__ == "__main__": test_size_value_converting_pass() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py index 9fbe674ef4f4..c41c66745012 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -1,10 +1,9 @@ -from functools import partial - import pytest import torch try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -16,7 +15,6 @@ class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features): super().__init__() self.linear = torch.nn.Linear(in_features, out_features) @@ -29,13 +27,11 @@ def forward(self, x): class ConvModel(torch.nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True): super().__init__() - self.conv = torch.nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - bias=bias) + self.conv = torch.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias + ) def forward(self, x): x = self.conv(x) @@ -46,7 +42,7 @@ def forward(self, x): def check_linear_module(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModel(4, 8).cuda() input = torch.rand(4, 4).cuda() output_compare = model(input) @@ -55,7 +51,7 @@ def check_linear_module(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4).to('meta')} + meta_args = {"x": torch.rand(4, 4).to("meta")} gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh) output = gm(input) assert_close(output, output_compare) @@ -63,7 +59,7 @@ def check_linear_module(rank, world_size, port): def check_conv_module(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvModel(3, 6, 2).cuda() input = torch.rand(4, 3, 64, 64).cuda() output_compare = model(input) @@ -72,14 +68,14 @@ def check_conv_module(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 3, 64, 64).to('meta')} + meta_args = {"x": torch.rand(4, 3, 64, 64).to("meta")} gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh) output = gm(input) assert_close(output, output_compare) -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_bias_addition_module(): @@ -87,5 +83,5 @@ def test_bias_addition_module(): spawn(check_conv_module, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_bias_addition_module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py index 5607587496f3..5cc1820837bb 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py @@ -48,17 +48,15 @@ def test_recover_sharding_spec_for_broadcast_shape(): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) broadcast_shape = get_broadcast_shape(x1.shape, x2.shape) - logical_sharding_spec_for_x1 = ShardingSpec(device_mesh=device_mesh, - dim_partition_dict={ - 0: [0], - 1: [1] - }, - entire_shape=broadcast_shape) + logical_sharding_spec_for_x1 = ShardingSpec( + device_mesh=device_mesh, dim_partition_dict={0: [0], 1: [1]}, entire_shape=broadcast_shape + ) physical_sharding_spec_for_x1, removed_dims = recover_sharding_spec_for_broadcast_shape( - logical_sharding_spec_for_x1, broadcast_shape, x1.shape) + logical_sharding_spec_for_x1, broadcast_shape, x1.shape + ) print(physical_sharding_spec_for_x1) assert physical_sharding_spec_for_x1.entire_shape == x1.shape # dim 1 for the physical tensor is of broadcast type MULTIPLE, so should ignore assert physical_sharding_spec_for_x1.dim_partition_dict == {0: [0]} - assert physical_sharding_spec_for_x1.sharding_sequence == ['S0', 'R', 'R'] + assert physical_sharding_spec_for_x1.sharding_sequence == ["S0", "R", "R"] diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py index 398458306e3d..c800f54da66c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py @@ -8,6 +8,7 @@ try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -21,7 +22,6 @@ class GPT2MLPWithCkpt(nn.Module): - def __init__(self, intermediate_size, hidden_size): super().__init__() embed_dim = hidden_size @@ -39,11 +39,11 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl def check_act_ckpt(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE) - input = torch.rand(1, 64, HIDDEN_SIZE) + torch.rand(1, 64, HIDDEN_SIZE) input_sample = { - 'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'), + "hidden_states": torch.rand(1, 64, HIDDEN_SIZE).to("meta"), } physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -51,18 +51,24 @@ def check_act_ckpt(rank, world_size, port): # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) gm = initialize_model(model, input_sample, device_mesh) - code = gm.module.graph.python_code('self').src - assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code - assert "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" in code + code = gm.module.graph.python_code("self").src + assert ( + "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" + in code + ) + assert ( + "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" + in code + ) -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_mlp_layer(): spawn(check_act_ckpt, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_mlp_layer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py index 6908a1781869..e8f175326bb1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py @@ -6,6 +6,7 @@ try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -17,7 +18,6 @@ class MLP(torch.nn.Module): - def __init__(self, in_features): super().__init__() self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False) @@ -32,7 +32,7 @@ def forward(self, x): def check_compatibility_with_ddp(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MLP(4).cuda() if rank in [0, 1]: input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda() @@ -49,26 +49,28 @@ def check_compatibility_with_ddp(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4).to('meta')} - gm, solution = initialize_model(model, - meta_args=meta_args, - device_mesh=device_mesh, - return_solution=True, - solver_preference='tp', - shard_option='shard_last_axis') - - msg = '| TP strategy combination chosen by auto-parallel solver |' + meta_args = {"x": torch.rand(4, 4).to("meta")} + gm, solution = initialize_model( + model, + meta_args=meta_args, + device_mesh=device_mesh, + return_solution=True, + solver_preference="tp", + shard_option="shard_last_axis", + ) + + msg = "| TP strategy combination chosen by auto-parallel solver |" msg_length = len(msg) if rank == 0: - print('=' * msg_length) + print("=" * msg_length) print(msg) - print('=' * msg_length) + print("=" * msg_length) for strategy in solution: print(strategy) - print('=' * msg_length) + print("=" * msg_length) dp_process_group = None - for (ranks, process_group_handle) in device_mesh.process_groups_dict[0]: + for ranks, process_group_handle in device_mesh.process_groups_dict[0]: if rank in ranks: dp_process_group = process_group_handle assert dp_process_group is not None @@ -79,7 +81,7 @@ def check_compatibility_with_ddp(rank, world_size, port): assert_close(output, output_compare.narrow(0, 0, 4)) else: assert_close(output, output_compare.narrow(0, 4, 4)) - print(f'output on rank{rank} is correct') + print(f"output on rank{rank} is correct") loss = output.sum() loss.backward() @@ -90,16 +92,16 @@ def check_compatibility_with_ddp(rank, world_size, port): if rank in (1, 3): assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 8, 8)) - print(f'gradient on rank{rank} is correct') + print(f"gradient on rank{rank} is correct") -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_compatibility_with_ddp(): spawn(check_compatibility_with_ddp, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_compatibility_with_ddp() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index 715f62358e2d..aba746f1992d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -5,6 +5,7 @@ try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False except: NO_CODEGEN = True @@ -19,7 +20,6 @@ class MLP(torch.nn.Module): - def __init__(self, in_features): super().__init__() self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False) @@ -34,7 +34,7 @@ def forward(self, x): def check_auto_parallel_with_gemini(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MLP(4).half().cuda() if rank in [0, 1]: input = torch.arange(0, 16).reshape(4, 4).half().cuda() @@ -51,29 +51,29 @@ def check_auto_parallel_with_gemini(rank, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4).half().to('meta')} - gm, solution = initialize_model(model, - meta_args=meta_args, - device_mesh=device_mesh, - return_solution=True, - solver_preference='tp', - shard_option='shard_last_axis') + meta_args = {"x": torch.rand(4, 4).half().to("meta")} + gm, solution = initialize_model( + model, + meta_args=meta_args, + device_mesh=device_mesh, + return_solution=True, + solver_preference="tp", + shard_option="shard_last_axis", + ) if rank == 0: - msg = '| TP strategy combination chosen by auto-parallel solver |' + msg = "| TP strategy combination chosen by auto-parallel solver |" msg_length = len(msg) - print('=' * msg_length) + print("=" * msg_length) print(msg) - print('=' * msg_length) + print("=" * msg_length) for strategy in solution: print(strategy) - print('=' * msg_length) + print("=" * msg_length) - gemini_config = dict(strict_ddp_mode=False, - device=get_current_device(), - placement_policy='cpu', - pin_memory=True, - search_range_m=128) + gemini_config = dict( + strict_ddp_mode=False, device=get_current_device(), placement_policy="cpu", pin_memory=True, search_range_m=128 + ) gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config) optimizer = HybridAdam(gm.parameters(), betas=(0, 0)) @@ -83,28 +83,28 @@ def check_auto_parallel_with_gemini(rank, world_size, port): assert_close(output, output_compare.narrow(0, 0, 4)) else: assert_close(output, output_compare.narrow(0, 4, 4)) - print(f'output on rank{rank} is correct') + print(f"output on rank{rank} is correct") loss = output.sum() optimizer.zero_grad() optimizer.backward(loss) optimizer.step() if rank in (0, 2): - assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 0, 8).flatten()) + assert_close(list(optimizer.optim.state.values())[0]["exp_avg"].half(), grad_compare.narrow(0, 0, 8).flatten()) if rank in (1, 3): - assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 8, 8).flatten()) + assert_close(list(optimizer.optim.state.values())[0]["exp_avg"].half(), grad_compare.narrow(0, 8, 8).flatten()) - print(f'gradient on rank{rank} is correct') + print(f"gradient on rank{rank} is correct") -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') +@run_on_environment_flag(name="AUTO_PARALLEL") +@pytest.mark.skipif(NO_CODEGEN, reason="No codegen found") @pytest.mark.dist @rerun_if_address_is_in_use() def test_auto_parallel_with_gemini(): spawn(check_auto_parallel_with_gemini, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_auto_parallel_with_gemini() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py index a0b407b240e1..a0276acc4293 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py @@ -5,8 +5,8 @@ from torch.fx import GraphModule from transformers.pytorch_utils import Conv1D -from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks @@ -19,7 +19,6 @@ class RepeatBlock(nn.Module): - def __init__(self, intermediate_size, hidden_size): super().__init__() self.c_fc = Conv1D(intermediate_size, hidden_size) @@ -35,13 +34,11 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl class RepeatModel(nn.Module): - def __init__(self, intermediate_size, hidden_size, num_layers): super().__init__() self.blocks = nn.ModuleList([RepeatBlock(intermediate_size, hidden_size) for i in range(num_layers)]) def forward(self, x): - for block in self.blocks: x = block(x) @@ -49,10 +46,9 @@ def forward(self, x): class NonRepeatBlock(nn.Module): - def __init__(self, intermediate_size, hidden_size, layer_index): super().__init__() - intermediate_size //= (layer_index + 1) + intermediate_size //= layer_index + 1 self.c_fc = Conv1D(intermediate_size, hidden_size) self.c_proj = Conv1D(hidden_size, intermediate_size) self.act = torch.nn.ReLU() @@ -66,28 +62,25 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl class NonRepeatModel(nn.Module): - def __init__(self, intermediate_size, hidden_size, num_layers): super().__init__() self.blocks = nn.ModuleList([NonRepeatBlock(intermediate_size, hidden_size, i) for i in range(num_layers)]) def forward(self, x): - for block in self.blocks: x = block(x) return x -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() -@parameterize('model_cls', [RepeatModel, NonRepeatModel]) +@parameterize("model_cls", [RepeatModel, NonRepeatModel]) def test_repeat_blocks(model_cls): - model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS) tracer = ColoTracer(bias_addition_split=True) - input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')} + input_sample = {"x": torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to("meta")} graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) @@ -110,5 +103,5 @@ def test_repeat_blocks(model_cls): assert len(common_blocks) == 0 -if __name__ == '__main__': +if __name__ == "__main__": test_repeat_blocks() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py index 22a2371311f9..3bb7cc409938 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py @@ -8,7 +8,6 @@ class GPT2MLP(nn.Module): - def __init__(self, intermediate_size, config): super().__init__() embed_dim = config.hidden_size @@ -34,15 +33,15 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl # 2. The order of split and view op has been changed in the customized GPT2Attention class, the new # order is same as megatron-lm gpt model. class GPT2Attention(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), - dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), ) self.register_buffer("masked_bias", torch.tensor(-1e4)) @@ -68,7 +67,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: - attn_weights = attn_weights / (value.size(-1)**0.5) + attn_weights = attn_weights / (value.size(-1) ** 0.5) # Layer-wise attention scaling if self.scale_attn_by_inverse_layer_idx: @@ -76,7 +75,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) if attention_mask is not None: @@ -100,7 +99,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): def _split_heads(self, tensor, num_heads, attn_head_size): new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) def _merge_heads(self, tensor, num_heads, attn_head_size): tensor = tensor.permute(0, 2, 1, 3).contiguous() @@ -113,7 +112,6 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - # query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) qkv = self.c_attn(hidden_states) @@ -121,7 +119,7 @@ def forward( # key = self._split_heads(key, self.num_heads, self.head_dim) # value = self._split_heads(value, self.num_heads, self.head_dim) query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3) - present = (key, value) + (key, value) attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) @@ -131,7 +129,6 @@ def forward( class GPT2Block(nn.Module): - def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -205,11 +202,9 @@ def forward( # GPT2Attention mask. attention_mask = attention_mask.view(batch_size, -1) attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * -10000.0 - encoder_attention_mask = None - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -267,7 +262,6 @@ def forward( class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py index 48d2672c6571..24968e670e3f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -9,6 +9,7 @@ from torch.fx import GraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer @@ -19,6 +20,7 @@ solve_solution, transform_to_sharded_model, ) + NO_CODEGEN = False except: NO_CODEGEN = True @@ -45,14 +47,17 @@ torch.backends.cudnn.benchmark = False -def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, torch.Tensor], - best_sharding_spec_dict: Dict[str, ShardingSpec]): +def _check_module_grad( + module: torch.nn.Module, + origin_param_dict: Dict[str, torch.Tensor], + best_sharding_spec_dict: Dict[str, ShardingSpec], +): for name, param in module.named_parameters(): param_grad = param.grad - name = name.replace('module.', '') + name = name.replace("module.", "") origin_param_grad = origin_param_dict[name].grad - atoms = name.split('.') - new_name = '_'.join(atoms) + atoms = name.split(".") + new_name = "_".join(atoms) if new_name in best_sharding_spec_dict: param_sharding_spec = best_sharding_spec_dict[new_name] grad_to_compare = copy.deepcopy(param_grad) @@ -63,19 +68,19 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, tor difference = param_grad_global - origin_param_grad avg_diff = difference.abs().sum() / difference.numel() assert avg_diff < 0.001 - print(f'{name} param has {avg_diff} average difference') + print(f"{name} param has {avg_diff} average difference") def check_attention_layer(rank, model_cls, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: - model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda') + model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to("cuda") else: - model = model_cls(config=config).to('cuda') + model = model_cls(config=config).to("cuda") test_model = copy.deepcopy(model) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) @@ -84,30 +89,30 @@ def check_attention_layer(rank, model_cls, world_size, port): hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32) if model_cls == GPT2MLP: - input_sample = (hidden_states.to('cuda'),) + input_sample = (hidden_states.to("cuda"),) test_input_sample = copy.deepcopy(input_sample) meta_input_sample = { - 'hidden_states': hidden_states.to('meta'), + "hidden_states": hidden_states.to("meta"), } elif model_cls in (GPT2Attention, GPT2Block): input_sample = ( - hidden_states.to('cuda'), - attention_mask.to('cuda'), + hidden_states.to("cuda"), + attention_mask.to("cuda"), ) test_input_sample = copy.deepcopy(input_sample) meta_input_sample = { - 'hidden_states': hidden_states.to('meta'), - 'attention_mask': attention_mask.to('meta'), + "hidden_states": hidden_states.to("meta"), + "attention_mask": attention_mask.to("meta"), } else: input_sample = ( - input_ids.to('cuda'), - attention_mask.to('cuda'), + input_ids.to("cuda"), + attention_mask.to("cuda"), ) test_input_sample = copy.deepcopy(input_sample) meta_input_sample = { - 'input_ids': input_ids.to('meta'), - 'attention_mask': attention_mask.to('meta'), + "input_ids": input_ids.to("meta"), + "attention_mask": attention_mask.to("meta"), } physical_mesh_id = torch.arange(0, 4) @@ -122,10 +127,11 @@ def check_attention_layer(rank, model_cls, world_size, port): shape_prop_pass(gm, *meta_input_sample.values()) gm.recompile() - strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard') + strategies_constructor = build_strategy_constructor(graph, device_mesh, "standard", "replicated", "standard") solution = solve_solution(gm, strategies_constructor, memory_budget=-1) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh, - strategies_constructor) + gm, sharding_spec_dicts = transform_to_sharded_model( + gm, meta_input_sample, solution, device_mesh, strategies_constructor + ) gm = ModuleWrapper(gm, *sharding_spec_dicts) nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -141,7 +147,7 @@ def check_attention_layer(rank, model_cls, world_size, port): output = gm(*input_sample) assert_close(output, origin_output, rtol=1e-03, atol=1e-03) - #*******************backward starting******************* + # *******************backward starting******************* cuda_rng_state = torch.cuda.get_rng_state() cpu_rng_state = torch.get_rng_state() output.sum().backward() @@ -158,9 +164,9 @@ def check_attention_layer(rank, model_cls, world_size, port): if rank == 0: print("*******************backward finished*******************") - #*******************backward finished******************* + # *******************backward finished******************* - #*******************strategy selected******************* + # *******************strategy selected******************* if rank == 0: print("*******************strategy selected*******************") nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -176,19 +182,19 @@ def check_attention_layer(rank, model_cls, world_size, port): node_memory_cost = node_memory_cost[0] memory_cost += node_memory_cost.activation + node_memory_cost.parameter - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - print(f'memory cost is {memory_cost}') + print(f"computation cost is {computation_cost}") + print(f"communication cost is {communication_cost}") + print(f"memory cost is {memory_cost}") -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.skipif(NO_CODEGEN, reason="no codegen module") @pytest.mark.dist -@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) +@parameterize("model_cls", [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) @rerun_if_address_is_in_use() def test_mlp_layer(model_cls): spawn(check_attention_layer, 4, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_mlp_layer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py index 5a8c3c4bf5a0..b61cbe170820 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -4,7 +4,6 @@ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass from colossalai._analyzer.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh @@ -18,9 +17,9 @@ HIDDEN_DIM = 384 -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() -@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) +@parameterize("model_cls", [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) def test_self_attention_block(model_cls): config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: @@ -32,23 +31,23 @@ def test_self_attention_block(model_cls): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - shape_consistency_manager = ShapeConsistencyManager() + ShapeConsistencyManager() tracer = ColoTracer(bias_addition_split=True) if model_cls == GPT2MLP: input_sample = { - 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), + "hidden_states": torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to("meta"), } elif model_cls in (GPT2Attention, GPT2Block): input_sample = { - 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), - 'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'), + "hidden_states": torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to("meta"), + "attention_mask": torch.rand(1, SEQ_LENGTH).to("meta"), } else: input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) - input_sample = {k: v.to('meta') for k, v in kwargs.items()} + input_sample = {k: v.to("meta") for k, v in kwargs.items()} graph = tracer.trace(root=model, meta_args=input_sample) @@ -63,7 +62,7 @@ def test_self_attention_block(model_cls): cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph.simplify_graph() solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1) - ret = solver.call_solver_serialized_args() + solver.call_solver_serialized_args() strategies_list = solver.last_s_val nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -79,10 +78,10 @@ def test_self_attention_block(model_cls): node_memory_cost = node_memory_cost[0] memory_cost += node_memory_cost.activation + node_memory_cost.parameter - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - print(f'memory cost is {memory_cost}') + print(f"computation cost is {computation_cost}") + print(f"communication cost is {communication_cost}") + print(f"memory cost is {memory_cost}") -if __name__ == '__main__': +if __name__ == "__main__": test_self_attention_block() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py index d10b222c060d..4dd04c69c8a5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py @@ -11,7 +11,6 @@ class LinearModel(nn.Module): - def __init__(self): super().__init__() self.linear1 = nn.Linear(4, 4) @@ -27,12 +26,12 @@ def forward(self, x1, x2): return out -@pytest.mark.skip('meta tensor has some bugs in 1.11') +@pytest.mark.skip("meta tensor has some bugs in 1.11") @clear_cache_before_run() def test_liveness_analysis(): model = LinearModel() tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 4, device='meta'), 'x2': torch.rand(4, 4, device='meta')} + meta_args = {"x1": torch.rand(4, 4, device="meta"), "x2": torch.rand(4, 4, device="meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__) shape_prop_pass(gm, *meta_args.values()) @@ -46,8 +45,8 @@ def test_liveness_analysis(): # a variable named `relu` must exist # and this live var must have inplace = True - assert liveness_list[0].all_live_vars.exists('relu') - relu_var = liveness_list[0].all_live_vars.get('relu') + assert liveness_list[0].all_live_vars.exists("relu") + relu_var = liveness_list[0].all_live_vars.get("relu") assert relu_var.is_inplace # the unique vars must be fewer than the all vars since in-place ops exist @@ -56,5 +55,5 @@ def test_liveness_analysis(): assert len(unique_live_vars) + 1 == len(all_live_vars) -if __name__ == '__main__': +if __name__ == "__main__": test_liveness_analysis() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py index e0a2133e654e..8831a208cb2f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -7,14 +7,17 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() -@parameterize('func', [ - torch.nn.functional.softmax, - torch.nn.functional.relu, - torch.tanh, - torch.nn.functional.dropout, -]) +@parameterize( + "func", + [ + torch.nn.functional.softmax, + torch.nn.functional.relu, + torch.tanh, + torch.nn.functional.dropout, + ], +) def test_activation_meta_info(func): meta_func = meta_register.get(func) # construct meta tensors @@ -23,13 +26,13 @@ def test_activation_meta_info(func): softmax_dim = 0 # construct operation data - input_data = OperationData(name='input', type=OperationDataType.ARG, data=input_tensor) - output_data = OperationData(name='output', type=OperationDataType.OUTPUT, data=output_tensor) - softmax_dim_data = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim) + input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor) + output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor) + softmax_dim_data = OperationData(name="softmax_dim", type=OperationDataType.ARG, data=softmax_dim) # construct args and kwargs args = [input_data, softmax_dim_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -54,9 +57,17 @@ def test_activation_meta_info(func): bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 - print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_activation_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py index 68ccc7835bc3..ba9e282144b7 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py @@ -3,7 +3,6 @@ import torch.nn as nn from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag @@ -12,7 +11,6 @@ class BinaryElementwiseOpModule(nn.Module): - def __init__(self, token=torch.add, shape=64) -> None: super().__init__() self.token = token @@ -33,7 +31,7 @@ def _binary_elementwise_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = BinaryElementwiseOpModule(token=torch.add, shape=1024).cuda() input = torch.rand(32, 1024).cuda() input.requires_grad = True @@ -45,21 +43,23 @@ def _binary_elementwise_mem_test(rank, world_size, port): node_index = 2 # total number of target node strategies strategy_number = 9 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_meta_concrete_info_match(): spawn(_binary_elementwise_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_binary_elementwise_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py index c6f7b88f44a5..45558154547f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py @@ -11,7 +11,6 @@ class ConvFunctionModule(nn.Module): - def __init__(self, in_channels=4, out_channels=64, kernel_size=3): super().__init__() self.conv_weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) @@ -32,7 +31,7 @@ def _conv_module_mem_test(rank, world_size, port, bias): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Conv2d(4, 64, 3, padding=1, bias=bias)).cuda() input = torch.rand(4, 4, 64, 64).cuda() input.requires_grad = True @@ -44,16 +43,18 @@ def _conv_module_mem_test(rank, world_size, port, bias): node_index = 1 # total number of target node strategies strategy_number = 16 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_conv_meta_concrete_info_match(bias=False): @@ -71,7 +72,7 @@ def _conv_function_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvFunctionModule().cuda() input = torch.rand(4, 4, 64, 64).cuda() input.requires_grad = True @@ -83,22 +84,24 @@ def _conv_function_mem_test(rank, world_size, port): node_index = 2 # total number of target node strategies strategy_number = 16 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_conv_function_concrete_info_match(): spawn(_conv_function_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": # test_conv_meta_concrete_info_match() test_conv_function_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py index e3f76a95c4a5..5d830d769c2d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py @@ -5,11 +5,11 @@ from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': +if torch.__version__ >= "1.12.0": from colossalai.auto_parallel.meta_profiler import meta_register -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() def test_embedding_meta_info(): meta_func = meta_register.get(torch.nn.Embedding) @@ -28,7 +28,7 @@ def test_embedding_meta_info(): # construct args and kwargs args = [input_data, weight_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -52,9 +52,17 @@ def test_embedding_meta_info(): bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 - print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_embedding_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py index fb3ded339ddf..639870c89a82 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -11,7 +11,6 @@ class MyModule(nn.Module): - def __init__(self, in_features=64, out_features=128): super().__init__() self.fc_weight = nn.Parameter(torch.randn(out_features, in_features)) @@ -31,7 +30,7 @@ def _linear_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Linear(64, 128, bias=False)).cuda() input = torch.rand(8, 8, 16, 64).cuda() input.requires_grad = True @@ -40,16 +39,18 @@ def _linear_module_mem_test(rank, world_size, port): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) # memory test - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=1, - strategy_number=13, - input_args=[input], - meta_arg_names=["input"]) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=1, + strategy_number=13, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_module_meta_concrete_info_match(): @@ -67,7 +68,7 @@ def _linear_function_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = MyModule().cuda() input = torch.rand(8, 8, 16, 64).cuda() input.requires_grad = True @@ -76,22 +77,24 @@ def _linear_function_mem_test(rank, world_size, port): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) # memory test - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=2, - strategy_number=24, - input_args=[input], - meta_arg_names=["input"]) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=2, + strategy_number=24, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_function_meta_concrete_info_match(): spawn(_linear_function_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": # test_linear_module_meta_concrete_info_match() test_linear_function_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py index 2d2d77f0c637..b182dd02ca76 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py @@ -5,26 +5,27 @@ from colossalai.testing.utils import clear_cache_before_run, parameterize from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register +if torch.__version__ >= "1.12.0": + from colossalai.auto_parallel.meta_profiler import meta_register -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() @parameterize( - 'tensor_shapes', + "tensor_shapes", [ - [[128], [128]], # dot product - [[64, 128], [128]], # mat-vec - [[128], [128, 64]], # vec-mat - [[64, 64, 128], [128]], # batched mat-vec - [[128], [64, 128, 64]], # vec-batched mat - [[64, 128], [128, 192]], # mat-mat - [[64, 64, 128], [128, 192]], # batched mat-mat - [[64, 128], [64, 128, 192]], # mat-batched mat - [[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims) - [[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims) - ]) + [[128], [128]], # dot product + [[64, 128], [128]], # mat-vec + [[128], [128, 64]], # vec-mat + [[64, 64, 128], [128]], # batched mat-vec + [[128], [64, 128, 64]], # vec-batched mat + [[64, 128], [128, 192]], # mat-mat + [[64, 64, 128], [128, 192]], # batched mat-mat + [[64, 128], [64, 128, 192]], # mat-batched mat + [[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims) + [[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims) + ], +) def test_matmul_function_meta_info(tensor_shapes): meta_func = meta_register.get(torch.matmul) @@ -55,7 +56,7 @@ def test_matmul_function_meta_info(tensor_shapes): # construct args and kwargs args = [input_data, other_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -85,9 +86,17 @@ def test_matmul_function_meta_info(tensor_shapes): compute_cost: TrainCycleItem memory_cost: TrainCycleItem - print_results([input_real_tensor, other_real_tensor], [output_real_tensor], compute_cost, memory_cost, - fwd_allocated, fwd_peak, bwd_allocated, bwd_peak) + print_results( + [input_real_tensor, other_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_matmul_function_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py index 808172977b60..ed809a758dfd 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py @@ -10,7 +10,7 @@ from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results -if torch.__version__ >= '1.12.0': +if torch.__version__ >= "1.12.0": from colossalai.auto_parallel.meta_profiler import meta_register @@ -25,7 +25,7 @@ def _batchnorm_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.BatchNorm2d(128)).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True @@ -37,27 +37,32 @@ def _batchnorm_module_mem_test(rank, world_size, port): node_index = 1 # total number of target node strategies strategy_number = 9 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) - - -@run_on_environment_flag(name='AUTO_PARALLEL') + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_batchnorm_meta_concrete_info_match(): spawn(_batchnorm_module_mem_test, 4) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations') -@parameterize('tensor_shape', [ - [256, 1024], - [1024, 256], -]) +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") +@parameterize( + "tensor_shape", + [ + [256, 1024], + [1024, 256], + ], +) def test_layernorm_meta_info(tensor_shape): meta_func = meta_register.get(torch.nn.LayerNorm) @@ -78,7 +83,7 @@ def test_layernorm_meta_info(tensor_shape): # construct args and kwargs args = [input_data, output_data, weight_data, bias_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -108,10 +113,18 @@ def test_layernorm_meta_info(tensor_shape): compute_cost: TrainCycleItem memory_cost: TrainCycleItem - print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_batchnorm_meta_concrete_info_match() test_layernorm_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py index 4cddf4e19fca..bd1deb40ca7b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py @@ -21,7 +21,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.AdaptiveAvgPool2d((16, 16))).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True @@ -33,16 +33,18 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port): node_index = 1 # total number of target strategies strategy_number = 1 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_adaptiveavgpool_meta_concrete_info_match(): @@ -60,7 +62,7 @@ def _maxpool_module_mem_test(rank, world_size, port): port: port for initializing process group """ disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.MaxPool2d((16, 16))).cuda() input = torch.rand(4, 128, 64, 64).cuda() input.requires_grad = True @@ -72,22 +74,24 @@ def _maxpool_module_mem_test(rank, world_size, port): node_index = 1 # total number of target node strategies strategy_number = 9 - mem_test_for_node_strategy(rank=rank, - model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + mem_test_for_node_strategy( + rank=rank, + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_maxpool_meta_concrete_info_match(): spawn(_maxpool_module_mem_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_adaptiveavgpool_meta_concrete_info_match() test_maxpool_meta_concrete_info_match() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py index 6e8145885d67..a29291e9b4d9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py @@ -6,12 +6,11 @@ from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register +if torch.__version__ >= "1.12.0": + from colossalai.auto_parallel.meta_profiler import meta_register class SplitModule(nn.Module): - def __init__(self) -> None: super().__init__() @@ -19,7 +18,7 @@ def forward(self, x): return x.split(512, dim=0) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() def test_tensor_meta_info(): """test tensor related meta information @@ -45,7 +44,7 @@ def test_tensor_meta_info(): logical_shape=input_tensor.shape, ) split_info_data = OperationData( - name='split_info', + name="split_info", type=OperationDataType.ARG, data=0, logical_shape=None, @@ -53,7 +52,7 @@ def test_tensor_meta_info(): # construct args args = [input_data, output_data, split_info_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -79,8 +78,16 @@ def test_tensor_meta_info(): bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0 bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0 - print_results([input_real_tensor], output_real_tensor, compute_cost, memory_cost, fwd_allocated, fwd_peak, - bwd_allocated, bwd_peak) + print_results( + [input_real_tensor], + output_real_tensor, + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) if __name__ == "__main__": diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py index b4564312eeb4..64d9ccd3def2 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py @@ -5,11 +5,11 @@ from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register +if torch.__version__ >= "1.12.0": + from colossalai.auto_parallel.meta_profiler import meta_register -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() def test_where_meta_info(): meta_func = meta_register.get(torch.where) @@ -49,7 +49,7 @@ def test_where_meta_info(): # construct args and kwargs args = [condition_data, x_data, y_data, output_data] - kwargs = {'inplace': False} + kwargs = {"inplace": False} # estimated results compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs) @@ -81,9 +81,17 @@ def test_where_meta_info(): compute_cost: TrainCycleItem memory_cost: TrainCycleItem - print_results([condition_real_tensor, x_real_tensor, y_real_tensor], [output_real_tensor], compute_cost, - memory_cost, fwd_allocated, fwd_peak, bwd_allocated, bwd_peak) + print_results( + [condition_real_tensor, x_real_tensor, y_real_tensor], + [output_real_tensor], + compute_cost, + memory_cost, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_where_meta_info() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py index 4ca85d34da30..e58d15cec50b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -7,6 +7,7 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass @@ -16,29 +17,34 @@ from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh -if torch.__version__ >= '1.12.0': +if torch.__version__ >= "1.12.0": from colossalai.auto_parallel.meta_profiler import ShardMetaInfo -def mem_test_for_node_strategy(rank: int, - model: torch.nn.Module, - device_mesh: DeviceMesh, - node_index: int, - strategy_number: int, - input_args: List[torch.Tensor], - meta_arg_names: List[str], - input_kwargs: Dict[str, torch.Tensor] = {}): +def mem_test_for_node_strategy( + rank: int, + model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}, +): for strategy_index in range(strategy_number): # We need to copy the model to avoid do backward more than once in same graph - model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy( - input_kwargs) + model_to_shard, args_to_shard, kwargs_to_shard = ( + copy.deepcopy(model), + copy.deepcopy(input_args), + copy.deepcopy(input_kwargs), + ) tracer = ColoTracer(bias_addition_split=True) input_sample = {} for input_arg, meta_arg_name in zip(input_args, meta_arg_names): - input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') + input_sample[meta_arg_name] = torch.rand(input_arg.shape).to("meta") for meta_kwarg_name, input_kwarg in input_kwargs.items(): - input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') + input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to("meta") graph = tracer.trace(root=model_to_shard, meta_args=input_sample) gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) shape_prop_pass(gm, *input_sample.values()) @@ -57,13 +63,18 @@ def mem_test_for_node_strategy(rank: int, # construct the strategy for the output node placeholder_strategy = list(graph.nodes)[-1].strategies_vector[0] - output_key = next(key for key in target_node.strategies_vector[strategy_index].sharding_specs.keys() - if key.type == OperationDataType.OUTPUT) + output_key = next( + key + for key in target_node.strategies_vector[strategy_index].sharding_specs.keys() + if key.type == OperationDataType.OUTPUT + ) placeholder_strategy.sharding_specs[output_key] = target_node.strategies_vector[strategy_index].sharding_specs[ - output_key] + output_key + ] gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( - gm, solution, device_mesh, strategies_constructor) + gm, solution, device_mesh, strategies_constructor + ) gm = runtime_apply_pass(gm) gm.recompile() gm: GraphModule @@ -76,22 +87,26 @@ def mem_test_for_node_strategy(rank: int, # warmup with torch.no_grad(): - output = gm(*args_to_shard, - sharding_spec_convert_dict=sharding_spec_dict, - origin_node_sharding_spec_dict=origin_spec_dict, - comm_actions_dict=comm_actions_dict, - **kwargs_to_shard) + output = gm( + *args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard, + ) del output # forward memory compare if rank == 0: torch.cuda.reset_peak_memory_stats() mem_stamp0 = torch.cuda.memory_allocated() - output = gm(*args_to_shard, - sharding_spec_convert_dict=sharding_spec_dict, - origin_node_sharding_spec_dict=origin_spec_dict, - comm_actions_dict=comm_actions_dict, - **kwargs_to_shard) + output = gm( + *args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard, + ) if rank == 0: # print forward memory allocated and peak memory stats in kb @@ -113,8 +128,10 @@ def mem_test_for_node_strategy(rank: int, # estimated memory if target_node.op == "call_module": - metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], - target_node.graph.owning_module.get_submodule(target_node.target)) + metainfo = ShardMetaInfo( + target_node.strategies_vector[strategy_index], + target_node.graph.owning_module.get_submodule(target_node.target), + ) else: metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], target_node.target) @@ -134,8 +151,16 @@ def mem_test_for_node_strategy(rank: int, print("=======================") -def print_results(input: List[torch.Tensor], output: List[torch.Tensor], compute_cost: TrainCycleItem, - memory_cost: TrainCycleItem, fwd_allocated, fwd_peak, bwd_allocated, bwd_peak): +def print_results( + input: List[torch.Tensor], + output: List[torch.Tensor], + compute_cost: TrainCycleItem, + memory_cost: TrainCycleItem, + fwd_allocated, + fwd_peak, + bwd_allocated, + bwd_peak, +): """Print the results of the meta information test. Args: diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py index 80e6a6c1460c..73a15f3ba4de 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py @@ -13,7 +13,6 @@ class AddBMMTensorMethodModule(nn.Module): - def __init__(self, using_kwargs): super().__init__() self.using_kwargs = using_kwargs @@ -27,7 +26,6 @@ def forward(self, bias, x1, x2): class AddBMMTorchFunctionModule(nn.Module): - def __init__(self, using_kwargs): super().__init__() self.using_kwargs = using_kwargs @@ -42,7 +40,7 @@ def forward(self, bias, x1, x2): def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module(using_kwargs).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -57,13 +55,15 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg # construct input args input_args = [bias, x1, x2] # construct meta arg names - meta_arg_names = ['bias', 'x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["bias", "x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer() # graph(): # %bias : torch.Tensor [#users=1] = placeholder[target=bias] @@ -73,13 +73,15 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) # return add - graph = tracer.trace(model, - meta_args={ - 'bias': torch.rand(*bias_shape).to('meta'), - "x1": torch.rand(4, 8, 16).to('meta'), - 'x2': torch.rand(4, 16, 8).to('meta') - }) - gm = ColoGraphModule(model, graph) + graph = tracer.trace( + model, + meta_args={ + "bias": torch.rand(*bias_shape).to("meta"), + "x1": torch.rand(4, 8, 16).to("meta"), + "x2": torch.rand(4, 16, 8).to("meta"), + }, + ) + ColoGraphModule(model, graph) bmm_mod_node = list(graph.nodes)[3] strategies_vector = StrategiesVector(bmm_mod_node) @@ -96,49 +98,49 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] for name in strategy_name_list: print(name) # one batch dim - assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + assert "Sb0 = Sb0 x Sb0" not in strategy_name_list # two batch dim - assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + assert "Sb01 = Sb01 x Sb01" in strategy_name_list # SbSi = SbSi x Sb - assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list - assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + assert "Sb0Si1 = Sb0Si1 x Sb0" in strategy_name_list + assert "Sb1Si0 = Sb1Si0 x Sb1" in strategy_name_list # SbSj = SbR x SbSj - assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list - assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + assert "Sb0Sj1 = Sb0R x Sb0Sj1" in strategy_name_list + assert "Sb1Sj0 = Sb1R x Sb1Sj0" in strategy_name_list # SbR = SbSk x SbSk - assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list - assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + assert "Sb0R = Sb0Sk1 x Sb0Sk1" in strategy_name_list + assert "Sb1R = Sb1Sk0 x Sb1Sk0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] @@ -148,7 +150,7 @@ def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwarg def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (1, 4) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -163,13 +165,15 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por # construct input args input_args = [bias, x1, x2] # construct meta arg names - meta_arg_names = ['bias', 'x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["bias", "x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer() # graph(): @@ -180,13 +184,15 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {}) # return add - graph = tracer.trace(model, - meta_args={ - 'bias': torch.rand(*bias_shape).to('meta'), - "x1": torch.rand(4, 8, 16).to('meta'), - 'x2': torch.rand(4, 16, 8).to('meta') - }) - gm = ColoGraphModule(model, graph) + graph = tracer.trace( + model, + meta_args={ + "bias": torch.rand(*bias_shape).to("meta"), + "x1": torch.rand(4, 8, 16).to("meta"), + "x2": torch.rand(4, 16, 8).to("meta"), + }, + ) + ColoGraphModule(model, graph) bmm_mod_node = list(graph.nodes)[3] strategies_vector = StrategiesVector(bmm_mod_node) @@ -202,33 +208,33 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] assert len(strategy_name_list) == 1 # one batch dim - assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + assert "Sb0 = Sb0 x Sb0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0] @@ -237,11 +243,11 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por @pytest.mark.skip("skip due to bias cases not ready") -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist -@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) -@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) -@parameterize('using_kwargs', [True, False]) +@parameterize("module", [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize("bias_shape", [[8], [1, 8], [8, 8]]) +@parameterize("using_kwargs", [True, False]) @rerun_if_address_is_in_use() def test_2d_device_mesh(module, bias_shape, using_kwargs): spawn( @@ -254,11 +260,11 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs): @pytest.mark.skip("skip due to bias cases not ready") -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist -@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) -@parameterize('bias_shape', [[8], [1, 8], [8, 8]]) -@parameterize('using_kwargs', [True, False]) +@parameterize("module", [AddBMMTorchFunctionModule, AddBMMTensorMethodModule]) +@parameterize("bias_shape", [[8], [1, 8], [8, 8]]) +@parameterize("using_kwargs", [True, False]) @rerun_if_address_is_in_use() def test_1d_device_mesh(module, bias_shape, using_kwargs): spawn( @@ -270,6 +276,6 @@ def test_1d_device_mesh(module, bias_shape, using_kwargs): ) -if __name__ == '__main__': +if __name__ == "__main__": test_1d_device_mesh() test_2d_device_mesh() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py index fe6554cd81ee..26f9c4ab1e3c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py @@ -19,7 +19,6 @@ class AddmmModel(nn.Module): - def __init__(self): super().__init__() @@ -29,7 +28,6 @@ def forward(self, input, m1, m2): class AddmmModel_with_param(nn.Module): - def __init__(self, weight_shape, bias_shape): super().__init__() self.weight = torch.nn.Parameter(torch.rand(weight_shape)) @@ -42,7 +40,7 @@ def forward(self, m1): def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") if model_cls == AddmmModel: model = AddmmModel().cuda() else: @@ -58,10 +56,10 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) # construct input args input_args = [input, m1, m2] # construct meta arg names - meta_arg_names = ['input', 'm1', 'm2'] + meta_arg_names = ["input", "m1", "m2"] meta_args_for_tracer = {} for meta_arg, input_arg in zip(meta_arg_names, input_args): - meta_args_for_tracer[meta_arg] = input_arg.to('meta') + meta_args_for_tracer[meta_arg] = input_arg.to("meta") # the index of addmm node in computation graph node_index = 4 @@ -72,22 +70,24 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) # construct input args input_args = [m1] # construct meta arg names - meta_arg_names = ['m1'] + meta_arg_names = ["m1"] # the index of addmm node in computation graph meta_args_for_tracer = {} for meta_arg, input_arg in zip(meta_arg_names, input_args): - meta_args_for_tracer[meta_arg] = input_arg.to('meta') + meta_args_for_tracer[meta_arg] = input_arg.to("meta") node_index = 4 # strategy number of linear node strategy_number = 14 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - node_type='bias_module') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type="bias_module", + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -117,60 +117,60 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) # check operation data mapping mapping = handler.get_operation_data_mapping() - assert mapping['input'].name == "m1" - assert mapping['input'].data.shape == torch.Size([4, 8]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8]) + assert mapping["input"].name == "m1" + assert mapping["input"].data.shape == torch.Size([4, 8]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8]) - assert mapping['other'].name == "transpose" - assert mapping['other'].data.shape == torch.Size([16, 8]) + assert mapping["other"].name == "transpose" + assert mapping["other"].data.shape == torch.Size([16, 8]) if model_cls == AddmmModel: - assert mapping['other'].type == OperationDataType.ARG + assert mapping["other"].type == OperationDataType.ARG else: - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([8, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([8, 16]) - assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([4, 16]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "linear" + assert mapping["output"].data.shape == torch.Size([4, 16]) + assert mapping["output"].type == OperationDataType.OUTPUT # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('m1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('transpose') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("m1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("transpose") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -178,14 +178,14 @@ def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls) assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist -@parameterize('input_shape', [(16,), (4, 16)]) -@parameterize('model_cls', [AddmmModel, AddmmModel_with_param]) +@parameterize("input_shape", [(16,), (4, 16)]) +@parameterize("model_cls", [AddmmModel, AddmmModel_with_param]) @rerun_if_address_is_in_use() def test_addmm_handler(input_shape, model_cls): spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_addmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py index c3ceef4c7adf..86df7237a219 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py @@ -16,7 +16,7 @@ def check_bn_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.BatchNorm2d(16)).cuda() physical_mesh_id = torch.arange(0, 4) @@ -29,18 +29,20 @@ def check_bn_module_handler(rank, world_size, port): # the total number of bn strategies without sync bn mode # TODO: add sync bn strategies after related passes ready strategy_number = 4 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 - meta_args = {"input": torch.rand(4, 16, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 16, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -59,37 +61,37 @@ def check_bn_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16, 64, 64]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16]) - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "_0" - assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "_0" + assert mapping["output"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # RS = RS x S - assert 'RS0 = RS0 x S0' in strategy_name_list - assert 'RS1 = RS1 x S1' in strategy_name_list + assert "RS0 = RS0 x S0" in strategy_name_list + assert "RS1 = RS1 x S1" in strategy_name_list # RR = RR x R - assert 'RR = RR x R' in strategy_name_list + assert "RR = RR x R" in strategy_name_list # RS01 = RS01 x S01 - assert 'RS01 = RS01 x S01' in strategy_name_list + assert "RS01 = RS01 x S01" in strategy_name_list # temporarily skip the sync bn test # TODO: test sync bn after the implicit runtime pass completed @@ -105,12 +107,12 @@ def check_bn_module_handler(rank, world_size, port): # assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_bn_module_handler(): spawn(check_bn_module_handler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_bn_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py index 800bc11a50e4..e06625e1c42c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py @@ -5,7 +5,7 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass from colossalai._analyzer.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -22,7 +22,6 @@ class LinearModule(torch.nn.Module): - def __init__(self, weight_shape): super().__init__() self.weight = torch.nn.Parameter(torch.rand(*weight_shape)) @@ -35,7 +34,7 @@ def forward(self, x): def check_linear_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModule(weight_shape=WEIGHT_SHAPE).cuda() physical_mesh_id = torch.arange(0, 4) @@ -49,14 +48,16 @@ def check_linear_module_handler(rank, world_size, port): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['x'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - node_type='bias_module') + meta_arg_names = ["x"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type="bias_module", + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -66,7 +67,7 @@ def check_linear_module_handler(rank, world_size, port): # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {}) # return add - meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')} + meta_args = {"x": torch.rand(4, 4, 4, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -85,72 +86,72 @@ def check_linear_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x" - assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([64, 16]) + assert mapping["input"].name == "x" + assert mapping["input"].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([64, 16]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16, 32]) - assert 'bias' not in mapping + assert "bias" not in mapping - assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "linear" + assert mapping["output"].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('x') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("x") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -158,12 +159,12 @@ def check_linear_module_handler(rank, world_size, port): assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(): spawn(check_linear_module_handler) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py index c29a065d10ba..690f0c12387c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -19,7 +19,6 @@ class LinearModule(torch.nn.Module): - def __init__(self, in_features, out_features, bias): super().__init__() self.linear = torch.nn.Linear(in_features, out_features, bias=bias) @@ -31,7 +30,7 @@ def forward(self, x): def check_linear_module_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModule(16, 32, bias=bias).cuda() physical_mesh_id = torch.arange(0, 4) @@ -45,17 +44,19 @@ def check_linear_module_handler(rank, world_size, port, bias): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['x'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - node_type='bias_module') + meta_arg_names = ["x"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + node_type="bias_module", + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')} + meta_args = {"x": torch.rand(4, 4, 4, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -74,72 +75,72 @@ def check_linear_module_handler(rank, world_size, port, bias): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x" - assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([64, 16]) + assert mapping["input"].name == "x" + assert mapping["input"].data.shape == torch.Size([4, 4, 4, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([64, 16]) - assert mapping['other'].name == "linear_weight" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "linear_weight" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16, 32]) - assert 'bias' not in mapping + assert "bias" not in mapping - assert mapping['output'].name == "linear" - assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "linear" + assert mapping["output"].data.shape == torch.Size([4, 4, 4, 32]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('x') - weight_sharding_spec = strategy.get_sharding_spec_by_name('linear_weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("x") + weight_sharding_spec = strategy.get_sharding_spec_by_name("linear_weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -147,12 +148,12 @@ def check_linear_module_handler(rank, world_size, port, bias): assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(bias=True): spawn(check_linear_module_handler, bias=bias) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py index 83f3aafe220e..5b2e2ab49f6d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py @@ -16,10 +16,9 @@ def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") class BinaryElementwiseOpModel(nn.Module): - def __init__(self, op): super().__init__() self.op = op @@ -41,16 +40,18 @@ def forward(self, x1, x2): # construct input args input_args = [x1, x2] # construct meta arg names - meta_arg_names = ['x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')} + meta_args = {"x1": torch.rand(4, 4).to("meta"), "x2": torch.rand([4] * other_dim).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -70,23 +71,23 @@ def forward(self, x1, x2): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4] * other_dim) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 4]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4] * other_dim) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 4]) - assert mapping['output'].name == str(op_node) - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([4, 4]) + assert mapping["output"].name == str(op_node) + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([4, 4]) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -95,19 +96,19 @@ def forward(self, x1, x2): assert len(strategy_name_list) == 9 # check if the sharding strategy is correct - assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list - assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list - assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list - assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list - assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list - assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list - assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list - assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list - assert '[R, R] = [R, R] [R, R]' in strategy_name_list + assert "[S0, S1] = [S0, S1] [S0, S1]" in strategy_name_list + assert "[S1, S0] = [S1, S0] [S1, S0]" in strategy_name_list + assert "[S01, R] = [S01, R] [S01, R]" in strategy_name_list + assert "[R, S01] = [R, S01] [R, S01]" in strategy_name_list + assert "[S0, R] = [S0, R] [S0, R]" in strategy_name_list + assert "[R, S0] = [R, S0] [R, S0]" in strategy_name_list + assert "[S1, R] = [S1, R] [S1, R]" in strategy_name_list + assert "[R, S1] = [R, S1] [R, S1]" in strategy_name_list + assert "[R, R] = [R, R] [R, R]" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) # make sure the sharding spec is the same for input and output @@ -121,7 +122,6 @@ def forward(self, x1, x2): class BEOpModelWithNodeConst(nn.Module): - def __init__(self, op): super().__init__() self.op = op @@ -133,7 +133,6 @@ def forward(self, x1): class BEOpModelWithIntConst(nn.Module): - def __init__(self, op, const): super().__init__() self.op = op @@ -146,7 +145,7 @@ def forward(self, x1): def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -163,15 +162,17 @@ def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_ # construct input args input_args = [x1] # construct meta arg names - meta_arg_names = ['x1'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 4).to('meta')} + meta_args = {"x1": torch.rand(4, 4).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -188,17 +189,17 @@ def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_ # check operation data mapping mapping = handler.get_operation_data_mapping() - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4]) - assert mapping['output'].name == str(op_node) - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([4, 4]) + assert mapping["output"].name == str(op_node) + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([4, 4]) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -207,27 +208,27 @@ def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_ assert len(strategy_name_list) == 9 # check if the sharding strategy is correct - assert '[S0, S1] = [S0, S1] [S0, S1]' in strategy_name_list - assert '[S1, S0] = [S1, S0] [S1, S0]' in strategy_name_list - assert '[S01, R] = [S01, R] [S01, R]' in strategy_name_list - assert '[R, S01] = [R, S01] [R, S01]' in strategy_name_list - assert '[S0, R] = [S0, R] [S0, R]' in strategy_name_list - assert '[R, S0] = [R, S0] [R, S0]' in strategy_name_list - assert '[S1, R] = [S1, R] [S1, R]' in strategy_name_list - assert '[R, S1] = [R, S1] [R, S1]' in strategy_name_list - assert '[R, R] = [R, R] [R, R]' in strategy_name_list + assert "[S0, S1] = [S0, S1] [S0, S1]" in strategy_name_list + assert "[S1, S0] = [S1, S0] [S1, S0]" in strategy_name_list + assert "[S01, R] = [S01, R] [S01, R]" in strategy_name_list + assert "[R, S01] = [R, S01] [R, S01]" in strategy_name_list + assert "[S0, R] = [S0, R] [S0, R]" in strategy_name_list + assert "[R, S0] = [R, S0] [R, S0]" in strategy_name_list + assert "[S1, R] = [S1, R] [S1, R]" in strategy_name_list + assert "[R, S1] = [R, S1] [R, S1]" in strategy_name_list + assert "[R, R] = [R, R] [R, R]" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") output_sharding_spec = strategy.get_sharding_spec_by_name(str(op_node)) # make sure the sharding spec is the same for input and output assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('op', [torch.add]) -@parameterize('other_dim', [1, 2]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("op", [torch.add]) +@parameterize("other_dim", [1, 2]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_handler_with_tensor(op, other_dim): @@ -239,10 +240,10 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim): ) -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('op', [torch.add]) -@parameterize('other_dim', [1, 2]) -@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("op", [torch.add]) +@parameterize("other_dim", [1, 2]) +@parameterize("model_cls", [BEOpModelWithNodeConst, BEOpModelWithIntConst]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_handler_with_int(op, model_cls, other_dim): @@ -255,6 +256,6 @@ def test_binary_elementwise_handler_with_int(op, model_cls, other_dim): ) -if __name__ == '__main__': +if __name__ == "__main__": test_binary_elementwise_handler_with_tensor() test_binary_elementwise_handler_with_int() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index f4fdc458f80e..29df12832241 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -15,20 +15,18 @@ class BMMTensorMethodModule(nn.Module): - def forward(self, x1, x2): return x1.bmm(x2) class BMMTorchFunctionModule(nn.Module): - def forward(self, x1, x2): return torch.bmm(x1, x2) def check_2d_device_mesh(rank, module, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -42,15 +40,17 @@ def check_2d_device_mesh(rank, module, world_size, port): # construct input args input_args = [x1, x2] # construct meta arg names - meta_arg_names = ['x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')} + meta_args = {"x1": torch.rand(4, 8, 16).to("meta"), "x2": torch.rand(4, 16, 8).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -70,48 +70,48 @@ def check_2d_device_mesh(rank, module, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # one batch dim - assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list + assert "Sb0 = Sb0 x Sb0" not in strategy_name_list # two batch dim - assert 'Sb01 = Sb01 x Sb01' in strategy_name_list + assert "Sb01 = Sb01 x Sb01" in strategy_name_list # SbSi = SbSi x Sb - assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list - assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list + assert "Sb0Si1 = Sb0Si1 x Sb0" in strategy_name_list + assert "Sb1Si0 = Sb1Si0 x Sb1" in strategy_name_list # SbSj = SbR x SbSj - assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list - assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list + assert "Sb0Sj1 = Sb0R x Sb0Sj1" in strategy_name_list + assert "Sb1Sj0 = Sb1R x Sb1Sj0" in strategy_name_list # SbR = SbSk x SbSk - assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list - assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list + assert "Sb0R = Sb0Sk1 x Sb0Sk1" in strategy_name_list + assert "Sb1R = Sb1Sk0 x Sb1Sk0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -121,7 +121,7 @@ def check_2d_device_mesh(rank, module, world_size, port): def check_1d_device_mesh(rank, module, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = module().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (1, 4) @@ -135,15 +135,17 @@ def check_1d_device_mesh(rank, module, world_size, port): # construct input args input_args = [x1, x2] # construct meta arg names - meta_arg_names = ['x1', 'x2'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["x1", "x2"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')} + meta_args = {"x1": torch.rand(4, 8, 16).to("meta"), "x2": torch.rand(4, 16, 8).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -162,33 +164,33 @@ def check_1d_device_mesh(rank, module, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 8, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 8, 16]) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 8, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 8, 16]) - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([4, 16, 8]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 8]) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([4, 16, 8]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 8]) - assert mapping['output'].name == "bmm" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 8, 8]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "bmm" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 8, 8]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] assert len(strategy_name_list) == 1 # one batch dim - assert 'Sb0 = Sb0 x Sb0' in strategy_name_list + assert "Sb0 = Sb0 x Sb0" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('bmm') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("bmm") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -196,9 +198,9 @@ def check_1d_device_mesh(rank, module, world_size, port): assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) -@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("module", [BMMTensorMethodModule, BMMTorchFunctionModule]) +@parameterize("module", [BMMTensorMethodModule, BMMTorchFunctionModule]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_bmm_handler(module): @@ -206,5 +208,5 @@ def test_bmm_handler(module): spawn(check_1d_device_mesh, 4, module=module) -if __name__ == '__main__': +if __name__ == "__main__": test_bmm_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index f9632b1cd8f9..8a37dd9256dd 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -16,7 +16,7 @@ def check_conv_module_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -32,14 +32,16 @@ def check_conv_module_handler(rank, world_size, port, bias): node_index = 1 # total number of conv strategies strategy_number = 16 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -58,76 +60,76 @@ def check_conv_module_handler(rank, world_size, port, bias): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" + assert mapping["input"].name == "input_1" # assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4, 64, 64]) - assert mapping['other'].name == "weight" + assert mapping["other"].name == "weight" # assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + assert mapping["other"].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([4, 16, 3, 3]) if bias: - assert mapping['bias'].name == "bias" + assert mapping["bias"].name == "bias" # assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "_0" + assert mapping["output"].name == "_0" # assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert "S0S1 = S0R x RS1" in strategy_name_list + assert "S1S0 = S1R x RS0" in strategy_name_list # SR = SR x RR - assert 'S0R = S0R x RR' in strategy_name_list - assert 'S1R = S1R x RR' in strategy_name_list + assert "S0R = S0R x RR" in strategy_name_list + assert "S1R = S1R x RR" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert "S0R = S0S1 x S1R" in strategy_name_list + assert "S1R = S1S0 x S0R" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR' in strategy_name_list + assert "S01R = S01R x RR" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("_0") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0] @@ -141,7 +143,6 @@ def check_conv_module_handler(rank, world_size, port, bias): class ConvModel(nn.Module): - def __init__(self): super().__init__() @@ -152,7 +153,7 @@ def forward(self, input, others, bias=None): def check_conv_function_handler(rank, world_size, port, bias): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = ConvModel().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -160,22 +161,24 @@ def check_conv_function_handler(rank, world_size, port, bias): input = torch.rand(4, 4, 64, 64).cuda() others = torch.rand(16, 4, 3, 3).cuda() input_args = [input, others] - meta_arg_names = ['input', 'others'] + meta_arg_names = ["input", "others"] input_kwargs = {} # total number of conv strategies strategy_number = 16 node_index = 2 if bias: bias_tensor = torch.rand(16).cuda() - input_kwargs['bias'] = bias_tensor + input_kwargs["bias"] = bias_tensor node_index += 1 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - input_kwargs=input_kwargs) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs, + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -183,9 +186,9 @@ def check_conv_function_handler(rank, world_size, port, bias): # %others : torch.Tensor [#users=1] = placeholder[target=others] # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {}) # return conv2d - meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta"), "others": torch.rand(16, 4, 3, 3).to("meta")} if bias: - meta_args['bias'] = torch.rand(16).to('meta') + meta_args["bias"] = torch.rand(16).to("meta") graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -208,76 +211,76 @@ def check_conv_function_handler(rank, world_size, port, bias): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4, 64, 64]) - assert mapping['other'].name == "others" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3]) + assert mapping["other"].name == "others" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([16, 4, 3, 3]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([4, 16, 3, 3]) if bias: - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.ARG - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.is_meta + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.ARG + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "conv2d" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "conv2d" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SS = SR x RS - assert 'S0S1 = S0R x RS1' in strategy_name_list - assert 'S1S0 = S1R x RS0' in strategy_name_list + assert "S0S1 = S0R x RS1" in strategy_name_list + assert "S1S0 = S1R x RS0" in strategy_name_list # SR = SR x RR - assert 'S0R = S0R x RR' in strategy_name_list - assert 'S1R = S1R x RR' in strategy_name_list + assert "S0R = S0R x RR" in strategy_name_list + assert "S1R = S1R x RR" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R' in strategy_name_list - assert 'S1R = S1S0 x S0R' in strategy_name_list + assert "S0R = S0S1 x S1R" in strategy_name_list + assert "S1R = S1S0 x S0R" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR' in strategy_name_list + assert "S01R = S01R x RR" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('others') - output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("others") + output_sharding_spec = strategy.get_sharding_spec_by_name("conv2d") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0] @@ -290,7 +293,7 @@ def check_conv_function_handler(rank, world_size, port, bias): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist # We temporarily ban the bias option before doing bias add # before all reduce communication may encounter correctness issue. @@ -300,7 +303,7 @@ def test_conv_module_handler(bias=False): spawn(check_conv_module_handler, 4, bias=bias) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist # We temporarily ban the bias option before doing bias add # before all reduce communication may encounter correctness issue. @@ -310,6 +313,6 @@ def test_conv_function_handler(bias=False): spawn(check_conv_function_handler, 4, bias=bias) -if __name__ == '__main__': +if __name__ == "__main__": test_conv_module_handler() test_conv_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py index 64f56ba98e2b..ce2ae4248fce 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py @@ -12,7 +12,6 @@ class ReshapeModel(nn.Module): - def __init__(self): super().__init__() @@ -22,7 +21,7 @@ def forward(self, input, other): return reshape_node -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_reshape_handler(): model = ReshapeModel() @@ -34,8 +33,8 @@ def test_reshape_handler(): # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) # return view meta_args = { - "input": torch.rand(4, 4, 64, 64).to('meta'), - "other": torch.rand(16, 4, 3, 3).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), + "other": torch.rand(16, 4, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -50,14 +49,14 @@ def test_reshape_handler(): conv_strategies_vector = StrategiesVector(conv_mod_node) # build handler - conv_handler = ConvFunctionHandler(node=conv_mod_node, - device_mesh=device_mesh, - strategies_vector=conv_strategies_vector) + conv_handler = ConvFunctionHandler( + node=conv_mod_node, device_mesh=device_mesh, strategies_vector=conv_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) - reshape_handler = DefaultReshapeHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=reshape_strategies_vector) + setattr(conv_mod_node, "strategies_vector", conv_strategies_vector) + reshape_handler = DefaultReshapeHandler( + node=reshape_node, device_mesh=device_mesh, strategies_vector=reshape_strategies_vector + ) reshape_handler.register_strategy(compute_resharding_cost=False) @@ -69,20 +68,20 @@ def test_reshape_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "conv2d" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].name == "conv2d" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16, 62, 62]) - assert mapping['output'].name == "view" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([2, 123008]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "view" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([2, 123008]) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(reshape_strategies_vector) == len(conv_strategies_vector) -if __name__ == '__main__': +if __name__ == "__main__": test_reshape_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py index 4fa0313b1cb5..9ac6ba95da48 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py @@ -22,7 +22,6 @@ class EmbeddingModule(nn.Module): - def __init__(self, num_embeddings, embedding_dims): super().__init__() self.embedding = nn.Embedding(num_embeddings, embedding_dims) @@ -34,7 +33,7 @@ def forward(self, input): def check_embedding_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = EmbeddingModule(num_embeddings=NUM_EMBEDDINGS, embedding_dims=EMBEDDING_DIMS).cuda() # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -51,15 +50,17 @@ def check_embedding_module_handler(rank, world_size, port): node_index = 1 # total number of embedding strategies strategy_number = 19 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input], - meta_arg_names=['input']) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input], + meta_arg_names=["input"], + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta')} + meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -78,60 +79,60 @@ def check_embedding_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" + assert mapping["input"].name == "input_1" # assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([1024]) + assert mapping["input"].data.shape == torch.Size([4, 16, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([1024]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['output'].name == "embedding" - assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + assert mapping["output"].name == "embedding" + assert mapping["output"].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # RR = RR x RR - assert 'RR = R x RR' in strategy_name_list + assert "RR = R x RR" in strategy_name_list # SR = SR x RR - assert 'S0R = S0 x RR_0' in strategy_name_list - assert 'S0R = S0 x RR_1' in strategy_name_list - assert 'S0R = S0 x RR_2' in strategy_name_list - assert 'S1R = S1 x RR_0' in strategy_name_list - assert 'S1R = S1 x RR_1' in strategy_name_list - assert 'S1R = S1 x RR_2' in strategy_name_list + assert "S0R = S0 x RR_0" in strategy_name_list + assert "S0R = S0 x RR_1" in strategy_name_list + assert "S0R = S0 x RR_2" in strategy_name_list + assert "S1R = S1 x RR_0" in strategy_name_list + assert "S1R = S1 x RR_1" in strategy_name_list + assert "S1R = S1 x RR_2" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0 x RS1_0' in strategy_name_list - assert 'S0S1 = S0 x RS1_1' in strategy_name_list - assert 'S0S1 = S0 x RS1_2' in strategy_name_list - assert 'S1S0 = S1 x RS0_0' in strategy_name_list - assert 'S1S0 = S1 x RS0_1' in strategy_name_list - assert 'S1S0 = S1 x RS0_2' in strategy_name_list + assert "S0S1 = S0 x RS1_0" in strategy_name_list + assert "S0S1 = S0 x RS1_1" in strategy_name_list + assert "S0S1 = S0 x RS1_2" in strategy_name_list + assert "S1S0 = S1 x RS0_0" in strategy_name_list + assert "S1S0 = S1 x RS0_1" in strategy_name_list + assert "S1S0 = S1 x RS0_2" in strategy_name_list # RS= RR x RS - assert 'RS0 = R x RS0' in strategy_name_list - assert 'RS1 = R x RS1' in strategy_name_list + assert "RS0 = R x RS0" in strategy_name_list + assert "RS1 = R x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01 x RR_0' in strategy_name_list - assert 'S01R = S01 x RR_1' in strategy_name_list - assert 'S01R = S01 x RR_2' in strategy_name_list + assert "S01R = S01 x RR_0" in strategy_name_list + assert "S01R = S01 x RR_1" in strategy_name_list + assert "S01R = S01 x RR_2" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = R x RS01' in strategy_name_list + assert "RS01 = R x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("embedding") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] @@ -139,7 +140,6 @@ def check_embedding_module_handler(rank, world_size, port): class EmbeddingFunction(nn.Module): - def __init__(self): super().__init__() @@ -150,7 +150,7 @@ def forward(self, input, others): def check_embedding_function_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = EmbeddingFunction().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -159,18 +159,20 @@ def check_embedding_function_handler(rank, world_size, port): input = input.to(torch.int64).cuda() others = torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).cuda() input_args = [input, others] - meta_arg_names = ['input', 'others'] + meta_arg_names = ["input", "others"] input_kwargs = {} # total number of embedding strategies strategy_number = 19 node_index = 2 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names, - input_kwargs=input_kwargs) + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + input_kwargs=input_kwargs, + ) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] @@ -178,8 +180,8 @@ def check_embedding_function_handler(rank, world_size, port): # %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False}) # return embedding meta_args = { - "input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta'), - "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta') + "input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to("meta"), + "others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -189,9 +191,9 @@ def check_embedding_function_handler(rank, world_size, port): strategies_vector = StrategiesVector(embedding_node) # build handler - handler = EmbeddingFunctionHandler(node=embedding_node, - device_mesh=device_mesh, - strategies_vector=strategies_vector) + handler = EmbeddingFunctionHandler( + node=embedding_node, device_mesh=device_mesh, strategies_vector=strategies_vector + ) # check operation data mapping mapping = handler.get_operation_data_mapping() @@ -202,82 +204,82 @@ def check_embedding_function_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([1024]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 16, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([1024]) - assert mapping['other'].name == "others" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].name == "others" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS]) - assert mapping['output'].name == "embedding" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) + assert mapping["output"].name == "embedding" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS]) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size([1024, EMBEDDING_DIMS]) handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # RR = RR x RR - assert 'RR = R x RR' in strategy_name_list + assert "RR = R x RR" in strategy_name_list # SR = SR x RR - assert 'S0R = S0 x RR_0' in strategy_name_list - assert 'S0R = S0 x RR_1' in strategy_name_list - assert 'S0R = S0 x RR_2' in strategy_name_list - assert 'S1R = S1 x RR_0' in strategy_name_list - assert 'S1R = S1 x RR_1' in strategy_name_list - assert 'S1R = S1 x RR_2' in strategy_name_list + assert "S0R = S0 x RR_0" in strategy_name_list + assert "S0R = S0 x RR_1" in strategy_name_list + assert "S0R = S0 x RR_2" in strategy_name_list + assert "S1R = S1 x RR_0" in strategy_name_list + assert "S1R = S1 x RR_1" in strategy_name_list + assert "S1R = S1 x RR_2" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0 x RS1_0' in strategy_name_list - assert 'S0S1 = S0 x RS1_1' in strategy_name_list - assert 'S0S1 = S0 x RS1_2' in strategy_name_list - assert 'S1S0 = S1 x RS0_0' in strategy_name_list - assert 'S1S0 = S1 x RS0_1' in strategy_name_list - assert 'S1S0 = S1 x RS0_2' in strategy_name_list + assert "S0S1 = S0 x RS1_0" in strategy_name_list + assert "S0S1 = S0 x RS1_1" in strategy_name_list + assert "S0S1 = S0 x RS1_2" in strategy_name_list + assert "S1S0 = S1 x RS0_0" in strategy_name_list + assert "S1S0 = S1 x RS0_1" in strategy_name_list + assert "S1S0 = S1 x RS0_2" in strategy_name_list # RS= RR x RS - assert 'RS0 = R x RS0' in strategy_name_list - assert 'RS1 = R x RS1' in strategy_name_list + assert "RS0 = R x RS0" in strategy_name_list + assert "RS1 = R x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01 x RR_0' in strategy_name_list - assert 'S01R = S01 x RR_1' in strategy_name_list - assert 'S01R = S01 x RR_2' in strategy_name_list + assert "S01R = S01 x RR_0" in strategy_name_list + assert "S01R = S01 x RR_1" in strategy_name_list + assert "S01R = S01 x RR_2" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = R x RS01' in strategy_name_list + assert "RS01 = R x RS01" in strategy_name_list for strategy in strategies_vector: - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('others') - output_sharding_spec = strategy.get_sharding_spec_by_name('embedding') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("others") + output_sharding_spec = strategy.get_sharding_spec_by_name("embedding") # make sure the sharding matches across different operation data assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1] assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1] -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_embedding_module_handler(): spawn(check_embedding_module_handler, 4) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_embedding_function_handler(): spawn(check_embedding_function_handler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_embedding_module_handler() test_embedding_function_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py index a089df743ec0..2c464f64d8ca 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py @@ -12,7 +12,6 @@ class GetattrModel(nn.Module): - def __init__(self): super().__init__() self.conv = nn.Conv2d(4, 16, 3, padding=1, bias=False) @@ -22,7 +21,7 @@ def forward(self, input): return weight -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") @clear_cache_before_run() def test_getattr_handler(): model = GetattrModel() @@ -31,7 +30,7 @@ def test_getattr_handler(): # %input_1 : torch.Tensor [#users=0] = placeholder[target=input] # %conv_weight : [#users=1] = get_attr[target=conv.weight] # return conv_weight - meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -42,9 +41,9 @@ def test_getattr_handler(): getattr_strategies_vector = StrategiesVector(getattr_node) # build handler - getattr_handler = GetattrHandler(node=getattr_node, - device_mesh=device_mesh, - strategies_vector=getattr_strategies_vector) + getattr_handler = GetattrHandler( + node=getattr_node, device_mesh=device_mesh, strategies_vector=getattr_strategies_vector + ) getattr_handler.register_strategy(compute_resharding_cost=False) @@ -56,20 +55,20 @@ def test_getattr_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['output'].name == "conv_weight" - assert mapping['output'].data.shape == torch.Size((16, 4, 3, 3)) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "conv_weight" + assert mapping["output"].data.shape == torch.Size((16, 4, 3, 3)) + assert mapping["output"].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in getattr_handler.strategies_vector] - assert 'get_attr [S0, S1, R, R]' in strategy_name_list - assert 'get_attr [S1, S0, R, R]' in strategy_name_list - assert 'get_attr [S01, R, R, R]' in strategy_name_list - assert 'get_attr [R, S01, R, R]' in strategy_name_list - assert 'get_attr [S0, R, R, R]' in strategy_name_list - assert 'get_attr [R, S0, R, R]' in strategy_name_list - assert 'get_attr [S1, R, R, R]' in strategy_name_list - assert 'get_attr [R, S1, R, R]' in strategy_name_list - assert 'get_attr [R, R, R, R]' in strategy_name_list + assert "get_attr [S0, S1, R, R]" in strategy_name_list + assert "get_attr [S1, S0, R, R]" in strategy_name_list + assert "get_attr [S01, R, R, R]" in strategy_name_list + assert "get_attr [R, S01, R, R]" in strategy_name_list + assert "get_attr [S0, R, R, R]" in strategy_name_list + assert "get_attr [R, S0, R, R]" in strategy_name_list + assert "get_attr [S1, R, R, R]" in strategy_name_list + assert "get_attr [R, S1, R, R]" in strategy_name_list + assert "get_attr [R, R, R, R]" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_getattr_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index a2e0968b18bb..cf802a228034 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -1,5 +1,3 @@ -from functools import partial - import pytest import torch import torch.nn as nn @@ -21,7 +19,6 @@ class GetItemFromTensorModel(nn.Module): - def __init__(self, getitem_index): super().__init__() self.getitem_index = getitem_index @@ -34,12 +31,12 @@ def forward(self, input, other): def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = GetItemFromTensorModel(getitem_index=getitem_index) - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -49,18 +46,20 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) meta_args = { - "input": torch.rand(8, 16, 64, 32).to('meta'), - "other": torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -72,14 +71,14 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): linear_strategies_vector = StrategiesVector(linear_mod_node) # build handler - linear_handler = LinearFunctionHandler(node=linear_mod_node, - device_mesh=device_mesh, - strategies_vector=linear_strategies_vector) + linear_handler = LinearFunctionHandler( + node=linear_mod_node, device_mesh=device_mesh, strategies_vector=linear_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(linear_mod_node, 'strategies_vector', linear_strategies_vector) - getitem_handler = GetItemHandler(node=getitem_mod_node, - device_mesh=device_mesh, - strategies_vector=getitem_strategies_vector) + setattr(linear_mod_node, "strategies_vector", linear_strategies_vector) + getitem_handler = GetItemHandler( + node=getitem_mod_node, device_mesh=device_mesh, strategies_vector=getitem_strategies_vector + ) getitem_handler.register_strategy(compute_resharding_cost=False) # check operation data mapping @@ -94,17 +93,16 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): assert len(getitem_strategies_vector) == len(linear_strategies_vector) -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() # @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))]) -@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) +@parameterize("getitem_index", [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) def test_getitem_from_tensor_handler(getitem_index): spawn(check_getitem_from_tensor_handler, 4) class GetItemFromTupleModel(nn.Module): - def __init__(self): super().__init__() @@ -114,7 +112,7 @@ def forward(self, input): return x -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_getitem_from_tuple_handler(): model = GetItemFromTupleModel() @@ -125,7 +123,7 @@ def test_getitem_from_tuple_handler(): # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {}) # return getitem meta_args = { - "input": torch.rand(4, 4, 64, 64).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -146,20 +144,20 @@ def test_getitem_from_tuple_handler(): node=input_node, device_mesh=device_mesh, strategies_vector=input_strategies_vector, - placeholder_option='replicated', + placeholder_option="replicated", ) input_handler.register_strategy(compute_resharding_cost=False) - setattr(input_node, 'strategies_vector', input_strategies_vector) - split_handler = DefaultReshapeHandler(node=split_node, - device_mesh=device_mesh, - strategies_vector=split_strategies_vector) + setattr(input_node, "strategies_vector", input_strategies_vector) + split_handler = DefaultReshapeHandler( + node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector + ) split_handler.register_strategy(compute_resharding_cost=False) - setattr(split_node, 'strategies_vector', split_strategies_vector) - getitem_handler = GetItemHandler(node=getitem_node, - device_mesh=device_mesh, - strategies_vector=getitem_strategies_vector) + setattr(split_node, "strategies_vector", split_strategies_vector) + getitem_handler = GetItemHandler( + node=getitem_node, device_mesh=device_mesh, strategies_vector=getitem_strategies_vector + ) getitem_handler.register_strategy(compute_resharding_cost=False) - setattr(getitem_node, 'strategies_vector', getitem_strategies_vector) + setattr(getitem_node, "strategies_vector", getitem_strategies_vector) # check operation data mapping mapping = getitem_handler.get_operation_data_mapping() @@ -169,23 +167,23 @@ def test_getitem_from_tuple_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "split" - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64])) + assert mapping["input"].name == "split" + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64])) - assert mapping['index'].name == "index" - assert isinstance(mapping['index'].data, int) - assert mapping['index'].type == OperationDataType.ARG + assert mapping["index"].name == "index" + assert isinstance(mapping["index"].data, int) + assert mapping["index"].type == OperationDataType.ARG - assert mapping['output'].name == "getitem" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([2, 4, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "getitem" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([2, 4, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(getitem_strategies_vector) == len(split_strategies_vector) -if __name__ == '__main__': +if __name__ == "__main__": test_getitem_from_tensor_handler() test_getitem_from_tuple_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py index ad72c2026b9a..59a66bc6a5d6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py @@ -17,7 +17,7 @@ def check_ln_module_handler(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.LayerNorm(16)).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -30,19 +30,21 @@ def check_ln_module_handler(rank, world_size, port): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['input'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["input"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 - meta_args = {"input": torch.rand(4, 16).to('meta')} + meta_args = {"input": torch.rand(4, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -62,45 +64,45 @@ def check_ln_module_handler(rank, world_size, port): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size([4, 16]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size([4, 16]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16]) - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16]) - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([16]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([16]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([16]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([16]) - assert mapping['output'].name == "_0" - assert mapping['output'].data.shape == torch.Size([4, 16]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "_0" + assert mapping["output"].data.shape == torch.Size([4, 16]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # SR = SR x R - assert '[S0, R] = [S0, R] x [R]' in strategy_name_list - assert '[S1, R] = [S1, R] x [R]' in strategy_name_list + assert "[S0, R] = [S0, R] x [R]" in strategy_name_list + assert "[S1, R] = [S1, R] x [R]" in strategy_name_list # RR = RR x R - assert 'RR = RR x R' in strategy_name_list + assert "RR = RR x R" in strategy_name_list # S01R = S01R x R - assert '[S01, R] = [S01, R] x [R]' in strategy_name_list + assert "[S01, R] = [S01, R] x [R]" in strategy_name_list -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() def test_ln_module_handler(): spawn(check_ln_module_handler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_ln_module_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index ec695cd8f7b9..da88b735f7c1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -23,7 +23,7 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -39,13 +39,15 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): # construct input args input_args = [input] # construct meta arg names - meta_arg_names = ['input'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["input"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) meta_args = {"input": torch.rand(input_shape).cuda()} @@ -68,86 +70,86 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size(input_shape) - assert mapping['input'].type == OperationDataType.ARG - input_logical_shape = mapping['input'].data.view(-1, 16).shape - assert mapping['input'].logical_shape == input_logical_shape + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size(input_shape) + assert mapping["input"].type == OperationDataType.ARG + input_logical_shape = mapping["input"].data.view(-1, 16).shape + assert mapping["input"].logical_shape == input_logical_shape - assert mapping['other'].name == "weight" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "weight" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.PARAM + assert mapping["other"].logical_shape == torch.Size([16, 32]) if bias: - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([32]) - assert mapping['bias'].type == OperationDataType.PARAM - assert mapping['bias'].logical_shape == torch.Size([32]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([32]) + assert mapping["bias"].type == OperationDataType.PARAM + assert mapping["bias"].logical_shape == torch.Size([32]) - assert mapping['output'].name == "_0" + assert mapping["output"].name == "_0" output_shape = input_shape[:-1] + (32,) - assert mapping['output'].data.shape == torch.Size(output_shape) - assert mapping['output'].type == OperationDataType.OUTPUT - output_logical_shape = mapping['output'].data.view(-1, 32).shape - assert mapping['output'].logical_shape == torch.Size(output_logical_shape) + assert mapping["output"].data.shape == torch.Size(output_shape) + assert mapping["output"].type == OperationDataType.OUTPUT + output_logical_shape = mapping["output"].data.view(-1, 32).shape + assert mapping["output"].logical_shape == torch.Size(output_logical_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # First dimension cannot be shard if input shape is (1, 4, 4, 16) if input_shape != (1, 4, 4, 16): - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S01R = S01R x RR_0' in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('weight') - output_sharding_spec = strategy.get_sharding_spec_by_name('_0') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("weight") + output_sharding_spec = strategy.get_sharding_spec_by_name("_0") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -159,7 +161,6 @@ def check_linear_module_handler(rank, world_size, port, bias, input_shape): class LinearModel(nn.Module): - def __init__(self): super().__init__() @@ -170,7 +171,7 @@ def forward(self, input, others, bias=None): def check_linear_function_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearModel().cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -188,16 +189,18 @@ def check_linear_function_handler(rank, world_size, port, bias, input_shape): # construct input args input_args = [input, other] # construct meta arg names - meta_arg_names = ['input', 'others'] - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=input_args, - meta_arg_names=meta_arg_names) + meta_arg_names = ["input", "others"] + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=input_args, + meta_arg_names=meta_arg_names, + ) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'input': torch.rand(input_shape).to('meta'), 'others': torch.rand(32, 16).to('meta')} + meta_args = {"input": torch.rand(input_shape).to("meta"), "others": torch.rand(32, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -214,86 +217,86 @@ def check_linear_function_handler(rank, world_size, port, bias, input_shape): # # check operation data mapping mapping = handler.get_operation_data_mapping() - assert mapping['input'].name == "input_1" - assert mapping['input'].data.shape == torch.Size(input_shape) - assert mapping['input'].type == OperationDataType.ARG - input_logical_shape = mapping['input'].data.view(-1, 16).shape - assert mapping['input'].logical_shape == torch.Size(input_logical_shape) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.shape == torch.Size(input_shape) + assert mapping["input"].type == OperationDataType.ARG + input_logical_shape = mapping["input"].data.view(-1, 16).shape + assert mapping["input"].logical_shape == torch.Size(input_logical_shape) - assert mapping['other'].name == "others" - assert mapping['other'].data.shape == torch.Size([32, 16]) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["other"].name == "others" + assert mapping["other"].data.shape == torch.Size([32, 16]) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([16, 32]) if bias: - assert mapping['bias'].name == "bias" - assert mapping['bias'].data.shape == torch.Size([32]) - assert mapping['bias'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size([16, 32]) + assert mapping["bias"].name == "bias" + assert mapping["bias"].data.shape == torch.Size([32]) + assert mapping["bias"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size([16, 32]) - assert mapping['output'].name == "linear" + assert mapping["output"].name == "linear" output_shape = input_shape[:-1] + (32,) - assert mapping['output'].data.shape == torch.Size(output_shape) - assert mapping['output'].type == OperationDataType.OUTPUT - output_logical_shape = mapping['output'].data.view(-1, 32).shape - assert mapping['output'].logical_shape == torch.Size(output_logical_shape) + assert mapping["output"].data.shape == torch.Size(output_shape) + assert mapping["output"].type == OperationDataType.OUTPUT + output_logical_shape = mapping["output"].data.view(-1, 32).shape + assert mapping["output"].logical_shape == torch.Size(output_logical_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] # First dimension cannot be shard if input shape is (1, 4, 4, 16) if input_shape != (1, 4, 4, 16): - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S01R = S01R x RR_0' in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list # SS = SR x RS - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('input_1') - weight_sharding_spec = strategy.get_sharding_spec_by_name('others') - output_sharding_spec = strategy.get_sharding_spec_by_name('linear') + input_sharding_spec = strategy.get_sharding_spec_by_name("input_1") + weight_sharding_spec = strategy.get_sharding_spec_by_name("others") + output_sharding_spec = strategy.get_sharding_spec_by_name("linear") if bias: - bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') + bias_sharding_spec = strategy.get_sharding_spec_by_name("bias") # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] @@ -304,8 +307,8 @@ def check_linear_function_handler(rank, world_size, port, bias, input_shape): assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] -@run_on_environment_flag(name='AUTO_PARALLEL') -@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)]) +@run_on_environment_flag(name="AUTO_PARALLEL") +@parameterize("input_shape", [(1, 4, 4, 16), (4, 4, 4, 16)]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(input_shape, bias=False): @@ -323,5 +326,5 @@ def test_linear_handler(input_shape, bias=False): ) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py index 938acd3d1eea..5fb4985e2f3c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py @@ -22,31 +22,31 @@ class MatMulModule(nn.Module): - def forward(self, x1, x2): return torch.matmul(x1, x2) -@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@pytest.mark.skipif(torch.__version__ < "1.12.0", reason="need pytorch 1.12.0 or higher for aten level operations") @clear_cache_before_run() @parameterize( - 'tensor_shapes', + "tensor_shapes", [ - [[8], [8]], # dot product - [[4, 8], [8]], # mat-vec product - [[4, 8], [8, 16]], # mat-mat product - [[8], [8, 16]], # mat-mat product - [[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting - [[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting - [[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting - [[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting - [[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting - [[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting - ]) + [[8], [8]], # dot product + [[4, 8], [8]], # mat-vec product + [[4, 8], [8, 16]], # mat-mat product + [[8], [8, 16]], # mat-mat product + [[8], [4, 8, 16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16]], # batched mat-mat product with padding + broadcasting + [[4, 8, 16], [16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [1, 16, 32]], # batched mat-mat product with broadcasting + [[8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[1, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 1, 8, 16], [2, 4, 16, 32]], # batched mat-mat product with broadcasting + [[2, 4, 8, 16], [2, 4, 16, 32]], # batched mat-mat product without broadcasting + ], +) def test_matmul_node_handler(tensor_shapes): input_shape, other_shape = tensor_shapes @@ -61,7 +61,7 @@ def test_matmul_node_handler(tensor_shapes): model = MatMulModule() tracer = ColoTracer(bias_addition_split=True) - meta_args = {"x1": x1.to('meta'), 'x2': x2.to('meta')} + meta_args = {"x1": x1.to("meta"), "x2": x2.to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -92,30 +92,31 @@ def test_matmul_node_handler(tensor_shapes): logical_input_shape = [1] + input_shape elif matmul_type == MatMulType.BMM: logical_input_shape, logical_other_shape, logical_output_shape = _get_bmm_logical_shape( - input_shape, other_shape, handler.transforms) + input_shape, other_shape, handler.transforms + ) else: logical_input_shape = input_shape # check input operation data - assert mapping['input'].name == "x1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size(input_shape) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size(logical_input_shape) + assert mapping["input"].name == "x1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size(input_shape) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size(logical_input_shape) # check other operation data - assert mapping['other'].name == "x2" - assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size(other_shape) - assert mapping['other'].type == OperationDataType.ARG - assert mapping['other'].logical_shape == torch.Size(logical_other_shape) + assert mapping["other"].name == "x2" + assert mapping["other"].data.is_meta + assert mapping["other"].data.shape == torch.Size(other_shape) + assert mapping["other"].type == OperationDataType.ARG + assert mapping["other"].logical_shape == torch.Size(logical_other_shape) # check output - assert mapping['output'].name == "matmul" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size(output_shape) - assert mapping['output'].type == OperationDataType.OUTPUT - assert mapping['output'].logical_shape == torch.Size(logical_output_shape) + assert mapping["output"].name == "matmul" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size(output_shape) + assert mapping["output"].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == torch.Size(logical_output_shape) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -126,9 +127,9 @@ def test_matmul_node_handler(tensor_shapes): for strategy in strategies_vector: strategy: ShardingStrategy - input_sharding_spec = strategy.get_sharding_spec_by_name('x1') - other_sharding_spec = strategy.get_sharding_spec_by_name('x2') - output_sharding_spec = strategy.get_sharding_spec_by_name('matmul') + input_sharding_spec = strategy.get_sharding_spec_by_name("x1") + other_sharding_spec = strategy.get_sharding_spec_by_name("x2") + output_sharding_spec = strategy.get_sharding_spec_by_name("matmul") if matmul_type == MatMulType.DOT: # dot product will produce a scaler # results should fulfill: @@ -171,5 +172,5 @@ def test_matmul_node_handler(tensor_shapes): assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1] -if __name__ == '__main__': +if __name__ == "__main__": test_matmul_node_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py index 6bff9f9648e2..6b7ac766ff18 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py @@ -10,16 +10,16 @@ from colossalai.testing import clear_cache_before_run, run_on_environment_flag -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_norm_pool_handler(): - model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) + model = nn.Sequential(nn.MaxPool2d(4, padding=1).to("meta")) tracer = ColoTracer(bias_addition_split=True) # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {}) # return _0 - meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"input": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -41,21 +41,21 @@ def test_norm_pool_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "input_1" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].name == "input_1" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 4, 64, 64]) - assert mapping['output'].name == "_0" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "_0" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4, 16, 16]) + assert mapping["output"].type == OperationDataType.OUTPUT strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] assert len(strategy_name_list) == 9 -if __name__ == '__main__': +if __name__ == "__main__": test_norm_pool_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py index 1703d5ded2f2..4da986181f89 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py @@ -12,7 +12,6 @@ class OutputModel(nn.Module): - def __init__(self): super().__init__() @@ -21,8 +20,8 @@ def forward(self, x): return x, y -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') -@parameterize('output_option', ['distributed', 'replicated']) +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") +@parameterize("output_option", ["distributed", "replicated"]) @clear_cache_before_run() def test_output_handler(output_option): model = OutputModel() @@ -31,7 +30,7 @@ def test_output_handler(output_option): # %x : torch.Tensor [#users=2] = placeholder[target=x] # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) # return (x, mul) - meta_args = {'x': torch.rand(4, 4, 64, 64).to('meta')} + meta_args = {"x": torch.rand(4, 4, 64, 64).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -43,10 +42,12 @@ def test_output_handler(output_option): output_strategies_vector = StrategiesVector(output_node) # build handler - output_handler = OutputHandler(node=output_node, - device_mesh=device_mesh, - strategies_vector=output_strategies_vector, - output_option=output_option) + output_handler = OutputHandler( + node=output_node, + device_mesh=device_mesh, + strategies_vector=output_strategies_vector, + output_option=output_option, + ) output_handler.register_strategy(compute_resharding_cost=False) # check operation data mapping @@ -57,14 +58,14 @@ def test_output_handler(output_option): # make sure they have valid values assert op_data.data is not None - assert mapping['output'].name == "output" - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "output" + assert mapping["output"].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in output_handler.strategies_vector] - if output_option == 'distributed': + if output_option == "distributed": assert "Distributed Output" in strategy_name_list else: assert "Replica Output" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_output_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py index f071cd120fb7..958dc288fa16 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -1,5 +1,3 @@ -from functools import partial - import pytest import torch import torch.nn as nn @@ -20,7 +18,6 @@ class ConvReshapeModel(nn.Module): - def __init__(self, reshape_dims, call_function): super().__init__() self.reshape_dims = reshape_dims @@ -37,7 +34,6 @@ def forward(self, input, other): class LinearReshapeModel(nn.Module): - def __init__(self, reshape_dims, call_function): super().__init__() self.reshape_dims = reshape_dims @@ -55,23 +51,23 @@ def forward(self, input, other): def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") if call_function == torch.permute: reshape_dims = reshape_dims[0] elif call_function == torch.transpose: reshape_dims = reshape_dims[1] model = model_cls(reshape_dims, call_function).cuda() - if model_cls.__name__ == 'ConvReshapeModel': - input = torch.rand(8, 8, 66, 66).to('cuda') - other = torch.rand(16, 8, 3, 3).to('cuda') + if model_cls.__name__ == "ConvReshapeModel": + input = torch.rand(8, 8, 66, 66).to("cuda") + other = torch.rand(16, 8, 3, 3).to("cuda") # index of conv node in computation graph node_index = 2 # total number of conv strategies strategy_number = 16 - if model_cls.__name__ == 'LinearReshapeModel': - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + if model_cls.__name__ == "LinearReshapeModel": + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -81,15 +77,17 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) - if model_cls.__name__ == 'ConvReshapeModel': + if model_cls.__name__ == "ConvReshapeModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -97,12 +95,12 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode # %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {}) # return permute meta_args = { - 'input': torch.rand(8, 8, 66, 66).to('meta'), - 'other': torch.rand(16, 8, 3, 3).to('meta'), + "input": torch.rand(8, 8, 66, 66).to("meta"), + "other": torch.rand(16, 8, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) - if model_cls.__name__ == 'LinearReshapeModel': + if model_cls.__name__ == "LinearReshapeModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -110,8 +108,8 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode # %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) # return permute meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -124,30 +122,29 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode previous_strategies_vector = StrategiesVector(previous_mod_node) # build handler - if model_cls.__name__ == 'ConvReshapeModel': - - conv_handler = ConvFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + if model_cls.__name__ == "ConvReshapeModel": + conv_handler = ConvFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - if model_cls.__name__ == 'LinearReshapeModel': + if model_cls.__name__ == "LinearReshapeModel": assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) if call_function == torch.permute: - reshape_handler = PermuteHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=view_strategies_vector) + reshape_handler = PermuteHandler( + node=reshape_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector + ) else: - reshape_handler = TransposeHandler(node=reshape_node, - device_mesh=device_mesh, - strategies_vector=view_strategies_vector) + reshape_handler = TransposeHandler( + node=reshape_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector + ) reshape_handler.register_strategy(compute_resharding_cost=False) @@ -159,25 +156,25 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode # make sure they have valid values assert op_data.data is not None - if model_cls.__name__ == 'ConvReshapeModel': - assert mapping['input'].name == "conv2d" + if model_cls.__name__ == "ConvReshapeModel": + assert mapping["input"].name == "conv2d" else: - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) if call_function == torch.permute: - assert mapping['output'].name == "permute" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "permute" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape + assert mapping["output"].type == OperationDataType.OUTPUT else: - assert mapping['output'].name == "transpose" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "transpose" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(view_strategies_vector) == len(previous_strategies_vector) @@ -185,146 +182,144 @@ def check_view_handler(rank, world_size, port, call_function, reshape_dims, mode if rank == 0: for name in strategy_name_list: print(name) - if model_cls.__name__ == 'ConvReshapeModel': - + if model_cls.__name__ == "ConvReshapeModel": if reshape_dims in ((0, 2, 1, 3), (1, 2)): - assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + assert "[S0, S1, R, R] -> [S0, R, S1, R]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [S1, R, S0, R]_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_15" in strategy_name_list if reshape_dims == (2, 0, 1, 3): - assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list + assert "[S0, S1, R, R] -> [R, S0, S1, R]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [R, S1, S0, R]_1" in strategy_name_list + assert "[S0, R, R, R] -> [R, S0, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [R, S1, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [R, S0, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [R, S1, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [R, S01, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_15" in strategy_name_list if reshape_dims == (1, 3): - assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, R, S01]_15' in strategy_name_list - - if model_cls.__name__ == 'LinearReshapeModel': - + assert "[S0, S1, R, R] -> [S0, R, R, S1]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [S1, R, R, S0]_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, R, S1]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, R, S0]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, R, S0]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, R, S01]_15" in strategy_name_list + + if model_cls.__name__ == "LinearReshapeModel": if reshape_dims == ((0, 2, 1, 3), (1, 2)): - assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S0, R, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S1, R, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, S0, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, S1, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, S01, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, R, S0, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, S0, R, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, R, S1, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, S1, R, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, S0, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, S1, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, S01, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list if reshape_dims == (2, 0, 1, 3): - assert '[S0, R, R, S1] -> [R, S0, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [S0, R, R, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [R, S1, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [S1, R, R, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [S0, R, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [S1, R, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R] -> [R, S01, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [S01, R, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert "[S0, R, R, S1] -> [R, S0, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, R, S0, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [S0, R, R, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [R, S1, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, R, S1, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [S1, R, R, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [R, S0, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, S0, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [S0, R, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [R, S1, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, S1, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [S1, R, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R] -> [R, S01, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, S01, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [S01, R, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list if reshape_dims == (1, 3): - assert '[S0, R, R, S1] -> [S0, S1, R, R]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S1, R, S0]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S1, S0, R]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, S0, R, R]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S0, R, S1]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S0, S1, R]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1, R, R]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0, R, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0, R, R]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1, R, R]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, R, S01]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, S01, R, R]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1] -> [S0, S1, R, R]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S1, R, S0]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, S1, S0, R]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, S0, R, R]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S0, R, S1]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, S0, S1, R]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, R, R, S0]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, R, R, S1]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1, R, R]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0, R, R]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0, R, R]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1, R, R]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, R, R, S01]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, S01, R, R]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('call_function', [torch.permute, torch.transpose]) -@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) -@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel]) +@parameterize("call_function", [torch.permute, torch.transpose]) +@parameterize("reshape_dims", [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) +@parameterize("model_cls", [ConvReshapeModel, LinearReshapeModel]) def test_view_handler(call_function, reshape_dims, model_cls): spawn( check_view_handler, @@ -335,5 +330,5 @@ def test_view_handler(call_function, reshape_dims, model_cls): ) -if __name__ == '__main__': +if __name__ == "__main__": test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py index 6d02b0e0ba74..60c090429c6c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py @@ -12,7 +12,6 @@ class PlaceholderModel(nn.Module): - def __init__(self): super().__init__() @@ -20,8 +19,8 @@ def forward(self, input): return input -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') -@parameterize('placeholder_option', ['distributed', 'replicated']) +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") +@parameterize("placeholder_option", ["distributed", "replicated"]) @clear_cache_before_run() def test_placeholder_handler(placeholder_option): model = PlaceholderModel() @@ -30,7 +29,7 @@ def test_placeholder_handler(placeholder_option): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # return input_1 meta_args = { - "input": torch.rand(4, 4, 64, 64).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -42,10 +41,12 @@ def test_placeholder_handler(placeholder_option): placeholder_node = list(graph.nodes)[0] placeholder_strategies_vector = StrategiesVector(placeholder_node) # build handler - placeholder_handler = PlaceholderHandler(node=placeholder_node, - device_mesh=device_mesh, - strategies_vector=placeholder_strategies_vector, - placeholder_option=placeholder_option) + placeholder_handler = PlaceholderHandler( + node=placeholder_node, + device_mesh=device_mesh, + strategies_vector=placeholder_strategies_vector, + placeholder_option=placeholder_option, + ) placeholder_handler.register_strategy(compute_resharding_cost=False) @@ -53,28 +54,28 @@ def test_placeholder_handler(placeholder_option): mapping = placeholder_handler.get_operation_data_mapping() strategy = placeholder_strategies_vector[0] - strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping['output'].name) + strategy_sharding_spec = strategy.get_sharding_spec_by_name(mapping["output"].name) - if placeholder_option == 'distributed': - assert str(strategy_sharding_spec.sharding_sequence) == '[S01, R, R, R]' + if placeholder_option == "distributed": + assert str(strategy_sharding_spec.sharding_sequence) == "[S01, R, R, R]" else: - assert str(strategy_sharding_spec.sharding_sequence) == '[R, R, R, R]' + assert str(strategy_sharding_spec.sharding_sequence) == "[R, R, R, R]" for name, op_data in mapping.items(): op_data: OperationData # make sure they have valid values assert op_data.data is not None - assert mapping['output'].name == "input_1" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size((4, 4, 64, 64)) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "input_1" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size((4, 4, 64, 64)) + assert mapping["output"].type == OperationDataType.OUTPUT strategy_name_list = [val.name for val in placeholder_handler.strategies_vector] - if placeholder_option == 'replicated': + if placeholder_option == "replicated": assert "Replica Placeholder" in strategy_name_list else: assert "Distributed Placeholder" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_placeholder_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py index 14c364c45fc4..6836a882242f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py @@ -12,7 +12,6 @@ class LinearModel(nn.Module): - def __init__(self): super().__init__() @@ -28,7 +27,7 @@ def check_shard_option(shard_option): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) tracer = ColoTracer(bias_addition_split=True) - meta_args = {'input': torch.rand(4, 4, 4, 16).to('meta'), 'others': torch.rand(32, 16).to('meta')} + meta_args = {"input": torch.rand(4, 4, 4, 16).to("meta"), "others": torch.rand(32, 16).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -36,77 +35,76 @@ def check_shard_option(shard_option): strategies_vector = StrategiesVector(linear_func_node) # build handler - handler = LinearFunctionHandler(node=linear_func_node, - device_mesh=device_mesh, - strategies_vector=strategies_vector, - shard_option=shard_option) + handler = LinearFunctionHandler( + node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector, shard_option=shard_option + ) strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] if shard_option == ShardOption.SHARD_LAST_AXIS: # RR = RS x SR - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list return # SS = SR x RS - assert 'S1S0 = S1R x RS0_0' in strategy_name_list - assert 'S0S1 = S0R x RS1_1' in strategy_name_list - assert 'S0S1 = S0R x RS1_2' in strategy_name_list - assert 'S0S1 = S0R x RS1_0' in strategy_name_list - assert 'S1S0 = S1R x RS0_1' in strategy_name_list - assert 'S1S0 = S1R x RS0_2' in strategy_name_list + assert "S1S0 = S1R x RS0_0" in strategy_name_list + assert "S0S1 = S0R x RS1_1" in strategy_name_list + assert "S0S1 = S0R x RS1_2" in strategy_name_list + assert "S0S1 = S0R x RS1_0" in strategy_name_list + assert "S1S0 = S1R x RS0_1" in strategy_name_list + assert "S1S0 = S1R x RS0_2" in strategy_name_list # SR = SS x SR - assert 'S0R = S0S1 x S1R_1' in strategy_name_list - assert 'S0R = S0S1 x S1R_2' in strategy_name_list - assert 'S1R = S1S0 x S0R_0' in strategy_name_list - assert 'S0R = S0S1 x S1R_0' in strategy_name_list - assert 'S1R = S1S0 x S0R_1' in strategy_name_list - assert 'S1R = S1S0 x S0R_2' in strategy_name_list + assert "S0R = S0S1 x S1R_1" in strategy_name_list + assert "S0R = S0S1 x S1R_2" in strategy_name_list + assert "S1R = S1S0 x S0R_0" in strategy_name_list + assert "S0R = S0S1 x S1R_0" in strategy_name_list + assert "S1R = S1S0 x S0R_1" in strategy_name_list + assert "S1R = S1S0 x S0R_2" in strategy_name_list # RS = RS x SS - assert 'RS0 = RS1 x S1S0' in strategy_name_list - assert 'RS1 = RS0 x S0S1' in strategy_name_list + assert "RS0 = RS1 x S1S0" in strategy_name_list + assert "RS1 = RS0 x S0S1" in strategy_name_list # S01R = S01R x RR - assert 'S01R = S01R x RR_0' in strategy_name_list - assert 'S01R = S01R x RR_1' in strategy_name_list - assert 'S01R = S01R x RR_2' in strategy_name_list + assert "S01R = S01R x RR_0" in strategy_name_list + assert "S01R = S01R x RR_1" in strategy_name_list + assert "S01R = S01R x RR_2" in strategy_name_list # RR = RS01 x S01R - assert 'RR = RS01 x S01R' in strategy_name_list + assert "RR = RS01 x S01R" in strategy_name_list # RS01 = RR x RS01 - assert 'RS01 = RR x RS01' in strategy_name_list + assert "RS01 = RR x RS01" in strategy_name_list if shard_option == ShardOption.SHARD: # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list if shard_option == ShardOption.STANDARD: # RR = RS x SR - assert 'RR = RS0 x S0R' in strategy_name_list - assert 'RR = RS1 x S1R' in strategy_name_list + assert "RR = RS0 x S0R" in strategy_name_list + assert "RR = RS1 x S1R" in strategy_name_list # RS= RR x RS - assert 'RS0 = RR x RS0' in strategy_name_list - assert 'RS1 = RR x RS1' in strategy_name_list + assert "RS0 = RR x RS0" in strategy_name_list + assert "RS1 = RR x RS1" in strategy_name_list # RR = RR x RR - assert 'RR = RR x RR' in strategy_name_list + assert "RR = RR x RR" in strategy_name_list -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_shard_option(): # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]: @@ -114,5 +112,5 @@ def test_shard_option(): check_shard_option(shard_option) -if __name__ == '__main__': +if __name__ == "__main__": test_shard_option() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py index 75ae0416ef98..1a99c32ebcb9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -17,7 +17,6 @@ class LinearSplitModel(nn.Module): - def __init__(self, softmax_dim): super().__init__() self.softmax_dim = softmax_dim @@ -30,11 +29,11 @@ def forward(self, input, other): def check_split_handler(rank, world_size, port, softmax_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(softmax_dim=softmax_dim).cuda() - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -44,13 +43,15 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) # graph(): @@ -60,8 +61,8 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): # %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) # return split meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -75,15 +76,15 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): # build handler assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - softmax_handler = SoftmaxHandler(node=split_node, - device_mesh=device_mesh, - strategies_vector=split_strategies_vector) + softmax_handler = SoftmaxHandler( + node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector + ) softmax_handler.register_strategy(compute_resharding_cost=False) @@ -95,84 +96,84 @@ def check_split_handler(rank, world_size, port, softmax_dim, model_cls): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['softmax_dim'].name == "softmax_dim" - assert mapping['softmax_dim'].data == softmax_dim - assert mapping['softmax_dim'].type == OperationDataType.ARG + assert mapping["softmax_dim"].name == "softmax_dim" + assert mapping["softmax_dim"].data == softmax_dim + assert mapping["softmax_dim"].type == OperationDataType.ARG - assert mapping['output'].name == "softmax" - assert mapping['output'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "softmax" + assert mapping["output"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["output"].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(split_strategies_vector) == len(previous_strategies_vector) strategy_name_list = [strategy.name for strategy in split_strategies_vector] if softmax_dim == 0: - assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S0, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1]_13" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S1, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0]_16" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list if softmax_dim == 1: - assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1] -> [S0, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('softmax_dim', [0, 1, 2, 3]) -@parameterize('model_cls', [LinearSplitModel]) +@parameterize("softmax_dim", [0, 1, 2, 3]) +@parameterize("model_cls", [LinearSplitModel]) def test_split_handler(softmax_dim, model_cls): spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py index f860c629b0a0..0318023c858d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -17,7 +17,6 @@ class ConvSplitModel(nn.Module): - def __init__(self, split_size, split_dim): super().__init__() self.split_size = split_size @@ -30,7 +29,6 @@ def forward(self, input, other): class LinearSplitModel(nn.Module): - def __init__(self, split_size, split_dim): super().__init__() self.split_size = split_size @@ -44,19 +42,19 @@ def forward(self, input, other): def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(split_size=split_size, split_dim=split_dim).cuda() - if model_cls.__name__ == 'ConvSplitModel': - input = torch.rand(8, 8, 66, 66).to('cuda') - other = torch.rand(16, 8, 3, 3).to('cuda') + if model_cls.__name__ == "ConvSplitModel": + input = torch.rand(8, 8, 66, 66).to("cuda") + other = torch.rand(16, 8, 3, 3).to("cuda") # index of conv node in computation graph node_index = 2 # total number of conv strategies strategy_number = 16 - if model_cls.__name__ == 'LinearSplitModel': - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + if model_cls.__name__ == "LinearSplitModel": + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -66,15 +64,17 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) - if model_cls.__name__ == 'ConvSplitModel': + if model_cls.__name__ == "ConvSplitModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -82,12 +82,12 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls # %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {}) # return split meta_args = { - 'input': torch.rand(8, 8, 66, 66).to('meta'), - 'other': torch.rand(16, 8, 3, 3).to('meta'), + "input": torch.rand(8, 8, 66, 66).to("meta"), + "other": torch.rand(16, 8, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) - if model_cls.__name__ == 'LinearSplitModel': + if model_cls.__name__ == "LinearSplitModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -95,8 +95,8 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls # %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {}) # return split meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -109,21 +109,20 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls previous_strategies_vector = StrategiesVector(previous_mod_node) # build handler - if model_cls.__name__ == 'ConvSplitModel': - - conv_handler = ConvFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + if model_cls.__name__ == "ConvSplitModel": + conv_handler = ConvFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - if model_cls.__name__ == 'LinearSplitModel': + if model_cls.__name__ == "LinearSplitModel": assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) split_handler = SplitHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector) @@ -137,124 +136,122 @@ def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls # make sure they have valid values assert op_data.data is not None - if model_cls.__name__ == 'ConvSplitModel': - assert mapping['input'].name == "conv2d" + if model_cls.__name__ == "ConvSplitModel": + assert mapping["input"].name == "conv2d" else: - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].name == "split" + assert mapping["output"].name == "split" split_items = torch.empty([8, 16, 64, 64]).split(split_size, split_dim) - assert mapping['output'].logical_shape == tuple([item.shape for item in split_items]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == tuple([item.shape for item in split_items]) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(split_strategies_vector) == len(previous_strategies_vector) strategy_name_list = [strategy.name for strategy in split_strategies_vector] - if model_cls.__name__ == 'ConvSplitModel': - + if model_cls.__name__ == "ConvSplitModel": if split_dim == 0: - assert '[R, S1, R, R]_0' in strategy_name_list - assert '[R, S0, R, R]_1' in strategy_name_list - assert '[R, R, R, R]_2' in strategy_name_list - assert '[R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, R]_4' in strategy_name_list - assert '[R, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R]_6' in strategy_name_list - assert '[R, S0, R, R]_7' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R]_10' in strategy_name_list - assert '[R, S1, R, R]_11' in strategy_name_list - assert '[R, R, R, R]_12' in strategy_name_list - assert '[R, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R]_15' in strategy_name_list + assert "[R, S1, R, R]_0" in strategy_name_list + assert "[R, S0, R, R]_1" in strategy_name_list + assert "[R, R, R, R]_2" in strategy_name_list + assert "[R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, R]_4" in strategy_name_list + assert "[R, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R]_6" in strategy_name_list + assert "[R, S0, R, R]_7" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R]_10" in strategy_name_list + assert "[R, S1, R, R]_11" in strategy_name_list + assert "[R, R, R, R]_12" in strategy_name_list + assert "[R, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R]_15" in strategy_name_list if split_dim == 1: - assert '[S0, R, R, R]_0' in strategy_name_list - assert '[S1, R, R, R]_1' in strategy_name_list - assert '[S0, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R]_5' in strategy_name_list - assert '[R, R, R, R]_6' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_9' in strategy_name_list - assert '[R, R, R, R]_10' in strategy_name_list - assert '[R, R, R, R]_11' in strategy_name_list - assert '[R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R]_15' in strategy_name_list - - if model_cls.__name__ == 'LinearSplitModel': - + assert "[S0, R, R, R]_0" in strategy_name_list + assert "[S1, R, R, R]_1" in strategy_name_list + assert "[S0, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R]_5" in strategy_name_list + assert "[R, R, R, R]_6" in strategy_name_list + assert "[R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_9" in strategy_name_list + assert "[R, R, R, R]_10" in strategy_name_list + assert "[R, R, R, R]_11" in strategy_name_list + assert "[R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R]_14" in strategy_name_list + assert "[R, R, R, R]_15" in strategy_name_list + + if model_cls.__name__ == "LinearSplitModel": if split_dim == 0: - assert '[R, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1]_13' in strategy_name_list - assert '[R, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0]_16' in strategy_name_list - assert '[R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R]_18' in strategy_name_list - assert '[R, R, S0, R]_19' in strategy_name_list - assert '[R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R]_21' in strategy_name_list - assert '[R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R]_1' in strategy_name_list - assert '[R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01]_4' in strategy_name_list + assert "[R, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1]_13" in strategy_name_list + assert "[R, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0]_16" in strategy_name_list + assert "[R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R]_18" in strategy_name_list + assert "[R, R, S0, R]_19" in strategy_name_list + assert "[R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R]_21" in strategy_name_list + assert "[R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R]_1" in strategy_name_list + assert "[R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01]_4" in strategy_name_list if split_dim == 1: - assert '[S0, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R]_17' in strategy_name_list - assert '[R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, S1]_5' in strategy_name_list - assert '[S01, R, R, R]_0' in strategy_name_list - assert '[R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R]_17" in strategy_name_list + assert "[R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R]_20" in strategy_name_list + assert "[R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, S1]_5" in strategy_name_list + assert "[S01, R, R, R]_0" in strategy_name_list + assert "[R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('split_size', [2]) -@parameterize('split_dim', [0, 1, 2]) -@parameterize('model_cls', [ConvSplitModel, LinearSplitModel]) +@parameterize("split_size", [2]) +@parameterize("split_dim", [0, 1, 2]) +@parameterize("model_cls", [ConvSplitModel, LinearSplitModel]) def test_split_handler(split_size, split_dim, model_cls): spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_split_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py index c11291ecac96..cbd3e47044b3 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py @@ -16,7 +16,6 @@ class LinearSumModel(nn.Module): - def __init__(self, sum_dims, keepdim): super().__init__() self.sum_dims = sum_dims @@ -33,26 +32,28 @@ def forward(self, input, other): def check_sum_handler(rank, world_size, port, sum_dims, keepdim): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda() physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies strategy_number = 24 - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) @@ -63,8 +64,8 @@ def check_sum_handler(rank, world_size, port, sum_dims, keepdim): # %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {}) # return sum_1 meta_args = { - "input": torch.rand(8, 16, 64, 32).to('meta'), - "other": torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -78,11 +79,11 @@ def check_sum_handler(rank, world_size, port, sum_dims, keepdim): # build handler assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) sum_handler = SumHandler(node=sum_node, device_mesh=device_mesh, strategies_vector=sum_strategies_vector) @@ -100,131 +101,131 @@ def check_sum_handler(rank, world_size, port, sum_dims, keepdim): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].name == "sum_1" + assert mapping["output"].name == "sum_1" sum_node_shape = torch.empty([8, 16, 64, 64]).sum(sum_dims, keepdim=keepdim).shape - assert mapping['output'].logical_shape == sum_node_shape - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].logical_shape == sum_node_shape + assert mapping["output"].type == OperationDataType.OUTPUT # check strategy name if sum_dims == (0, 2) and keepdim == False: - assert '[R, R, R, R] -> [R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [S01, R]_1' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_10' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [S0, S1]_12' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1]_13' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [S1, S0]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [S0, R]_18' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_19' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [S1, R]_21' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list + assert "[R, R, R, R] -> [R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [S01, R]_1" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_10" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [S0, S1]_12" in strategy_name_list + assert "[R, R, R, S1] -> [R, S1]_13" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [S1, S0]_15" in strategy_name_list + assert "[R, R, R, S0] -> [R, S0]_16" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [S0, R]_18" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_19" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [S1, R]_21" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R]_23" in strategy_name_list if sum_dims == (0, 2) and keepdim == True: - assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_13' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R]_1" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S0, R, S1]_12" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_13" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S1, R, S0]_15" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_16" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R]_18" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_19" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R]_21" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_23" in strategy_name_list if sum_dims == 1 and keepdim == False: - assert '[S01, R, R, R] -> [S01, R, R]_0' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, S1]_10' in strategy_name_list - assert '[S0, R, R, S1] -> [S0, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, S1, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R]_0" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, S1]_10" in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R]_17" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R]_20" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, S1, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R]_23" in strategy_name_list if sum_dims == 1 and keepdim == True: - assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list - assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S01, R, R, R] -> [S01, R, R, R]_0" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01]_4" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_5" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_6" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_9" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_10" in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, R, S1]_11" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0]_14" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R]_17" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R]_20" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R]_22" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R]_23" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('sum_dims', [(0, 2), 1]) -@parameterize('keepdim', [False, True]) +@parameterize("sum_dims", [(0, 2), 1]) +@parameterize("keepdim", [False, True]) def test_sum_handler(sum_dims, keepdim): spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim) -if __name__ == '__main__': +if __name__ == "__main__": test_sum_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py index 5b6ac051a8ef..29089183165d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py @@ -11,7 +11,6 @@ class TensorConstructorModel(nn.Module): - def __init__(self): super().__init__() @@ -21,7 +20,7 @@ def forward(self, x): return x -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_where_handler(): model = TensorConstructorModel() @@ -33,7 +32,7 @@ def test_where_handler(): # %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {}) # %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {}) # return add - meta_args = {'x': torch.rand(10).to('meta')} + meta_args = {"x": torch.rand(10).to("meta")} graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) shape_prop_pass(gm, *meta_args.values()) @@ -56,16 +55,16 @@ def test_where_handler(): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['output'].name == "arange" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([10]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "arange" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([10]) + assert mapping["output"].type == OperationDataType.OUTPUT handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] - assert 'Replica Tensor Constructor' in strategy_name_list + assert "Replica Tensor Constructor" in strategy_name_list -if __name__ == '__main__': +if __name__ == "__main__": test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py index f4e6dafdfd69..271d55ae917a 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -12,7 +12,6 @@ class ReLuModel(nn.Module): - def __init__(self): super().__init__() self.act = torch.nn.ReLU() @@ -23,7 +22,7 @@ def forward(self, input, other): return relu_node -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_elementwise_handler(): model = ReLuModel() @@ -35,8 +34,8 @@ def test_elementwise_handler(): # %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {}) # return act meta_args = { - 'input': torch.rand(4, 4, 64, 64).to('meta'), - 'other': torch.rand(16, 4, 3, 3).to('meta'), + "input": torch.rand(4, 4, 64, 64).to("meta"), + "other": torch.rand(16, 4, 3, 3).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -51,14 +50,14 @@ def test_elementwise_handler(): conv_strategies_vector = StrategiesVector(conv_mod_node) # build handler - conv_handler = ConvFunctionHandler(node=conv_mod_node, - device_mesh=device_mesh, - strategies_vector=conv_strategies_vector) + conv_handler = ConvFunctionHandler( + node=conv_mod_node, device_mesh=device_mesh, strategies_vector=conv_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) - relu_handler = UnaryElementwiseHandler(node=relu_mod_node, - device_mesh=device_mesh, - strategies_vector=relu_strategies_vector) + setattr(conv_mod_node, "strategies_vector", conv_strategies_vector) + relu_handler = UnaryElementwiseHandler( + node=relu_mod_node, device_mesh=device_mesh, strategies_vector=relu_strategies_vector + ) relu_handler.register_strategy(compute_resharding_cost=False) @@ -70,20 +69,20 @@ def test_elementwise_handler(): # make sure they have valid values assert op_data.data is not None - assert mapping['input'].name == "conv2d" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].name == "conv2d" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([4, 16, 62, 62]) - assert mapping['output'].name == "act" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 16, 62, 62]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "act" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 16, 62, 62]) + assert mapping["output"].type == OperationDataType.OUTPUT # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(relu_strategies_vector) == len(conv_strategies_vector) -if __name__ == '__main__': +if __name__ == "__main__": test_elementwise_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py index fbb194d8e0b8..466168c79a0b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -18,7 +18,6 @@ class ConvViewModel(nn.Module): - def __init__(self, tgt_shape): super().__init__() self.tgt_shape = tgt_shape @@ -30,7 +29,6 @@ def forward(self, input, other): class LinearViewModel(nn.Module): - def __init__(self, tgt_shape): super().__init__() self.tgt_shape = tgt_shape @@ -43,19 +41,19 @@ def forward(self, input, other): def check_view_handler(rank, tgt_shape, model_cls, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") model = model_cls(tgt_shape).cuda() - if model_cls.__name__ == 'ConvViewModel': - input = torch.rand(8, 8, 66, 66).to('cuda') - other = torch.rand(16, 8, 3, 3).to('cuda') + if model_cls.__name__ == "ConvViewModel": + input = torch.rand(8, 8, 66, 66).to("cuda") + other = torch.rand(16, 8, 3, 3).to("cuda") # index of conv node in computation graph node_index = 2 # total number of conv strategies strategy_number = 16 - if model_cls.__name__ == 'LinearViewModel': - input = torch.rand(8, 16, 64, 32).to('cuda') - other = torch.rand(64, 32).to('cuda') + if model_cls.__name__ == "LinearViewModel": + input = torch.rand(8, 16, 64, 32).to("cuda") + other = torch.rand(64, 32).to("cuda") # index of linear node in computation graph node_index = 2 # total number of linear strategies @@ -65,25 +63,27 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - numerical_test_for_node_strategy(model=model, - device_mesh=device_mesh, - node_index=node_index, - strategy_number=strategy_number, - input_args=[input, other], - meta_arg_names=['input', 'other'], - node_type='following') + numerical_test_for_node_strategy( + model=model, + device_mesh=device_mesh, + node_index=node_index, + strategy_number=strategy_number, + input_args=[input, other], + meta_arg_names=["input", "other"], + node_type="following", + ) tracer = ColoTracer(bias_addition_split=True) - if model_cls.__name__ == 'ConvViewModel': + if model_cls.__name__ == "ConvViewModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) # return view - meta_args = {'input': torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta')} + meta_args = {"input": torch.rand(8, 8, 66, 66).to("meta"), "other": torch.rand(16, 8, 3, 3).to("meta")} graph = tracer.trace(model, meta_args=meta_args) - if model_cls.__name__ == 'LinearViewModel': + if model_cls.__name__ == "LinearViewModel": # graph(): # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] # %other : torch.Tensor [#users=1] = placeholder[target=other] @@ -91,8 +91,8 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): # %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {}) # return view meta_args = { - 'input': torch.rand(8, 16, 64, 32).to('meta'), - 'other': torch.rand(64, 32).to('meta'), + "input": torch.rand(8, 16, 64, 32).to("meta"), + "other": torch.rand(64, 32).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) @@ -105,21 +105,20 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): previous_strategies_vector = StrategiesVector(previous_mod_node) # build handler - if model_cls.__name__ == 'ConvViewModel': - - conv_handler = ConvFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + if model_cls.__name__ == "ConvViewModel": + conv_handler = ConvFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) conv_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) - if model_cls.__name__ == 'LinearViewModel': + if model_cls.__name__ == "LinearViewModel": assert len(previous_strategies_vector) == 0 - linear_handler = LinearFunctionHandler(node=previous_mod_node, - device_mesh=device_mesh, - strategies_vector=previous_strategies_vector) + linear_handler = LinearFunctionHandler( + node=previous_mod_node, device_mesh=device_mesh, strategies_vector=previous_strategies_vector + ) linear_handler.register_strategy(compute_resharding_cost=False) - setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector) + setattr(previous_mod_node, "strategies_vector", previous_strategies_vector) view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector) @@ -133,126 +132,124 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): # make sure they have valid values assert op_data.data is not None - if model_cls.__name__ == 'ConvViewModel': - assert mapping['input'].name == "conv2d" + if model_cls.__name__ == "ConvViewModel": + assert mapping["input"].name == "conv2d" else: - assert mapping['input'].name == "linear" - assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) - assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].name == "linear" + assert mapping["input"].data.is_meta + assert mapping["input"].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping["input"].type == OperationDataType.ARG + assert mapping["input"].logical_shape == torch.Size([8, 16, 64, 64]) - assert mapping['output'].name == "view" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size(tgt_shape) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["output"].name == "view" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size(tgt_shape) + assert mapping["output"].type == OperationDataType.OUTPUT # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. assert len(view_strategies_vector) == len(previous_strategies_vector) strategy_name_list = [strategy.name for strategy in view_strategies_vector] - if model_cls.__name__ == 'ConvViewModel': - + if model_cls.__name__ == "ConvViewModel": if tgt_shape == (32, 4, 64, 16, 4): - assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list - assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list + assert "[S0, S1, R, R] -> FULLY REPLICATED_0" in strategy_name_list + assert "[S1, S0, R, R] -> FULLY REPLICATED_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> FULLY REPLICATED_6" in strategy_name_list + assert "[R, S0, R, R] -> FULLY REPLICATED_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> FULLY REPLICATED_10" in strategy_name_list + assert "[R, S1, R, R] -> FULLY REPLICATED_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> FULLY REPLICATED_15" in strategy_name_list if tgt_shape == (8, 4, 4, 64, 16, 4): - assert '[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0' in strategy_name_list - assert '[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_2' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_3' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_4' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_5' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_9' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_10' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_11' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_12' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_15' in strategy_name_list - - if model_cls.__name__ == 'LinearViewModel': - + assert "[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0" in strategy_name_list + assert "[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R, R]_2" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R, R]_3" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R, R]_4" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R, R]_5" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R, R, R]_6" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_9" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R, R, R]_10" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R, R, R]_11" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_12" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R, R]_13" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_14" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R, R, R]_15" in strategy_name_list + + if model_cls.__name__ == "LinearViewModel": if tgt_shape == (32, 4, 64, 16, 4): for strategy in strategy_name_list: print(strategy) # print(strategy_name_list) - assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_11' in strategy_name_list - assert '[R, S0, R, S1] -> FULLY REPLICATED_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_14' in strategy_name_list - assert '[R, S1, R, S0] -> FULLY REPLICATED_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1, R]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0, R]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1, R]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> FULLY REPLICATED_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01, R]_4' in strategy_name_list + assert "[S0, R, R, S1] -> [S0, R, R, S1, R]_11" in strategy_name_list + assert "[R, S0, R, S1] -> FULLY REPLICATED_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, S0, S1, R]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, S0, R]_14" in strategy_name_list + assert "[R, S1, R, S0] -> FULLY REPLICATED_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, S1, S0, R]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> FULLY REPLICATED_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, S0, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> FULLY REPLICATED_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, S1, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1, R]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0, R]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, S0, R]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, S1, R]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> FULLY REPLICATED_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, S01, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, S01, R]_4" in strategy_name_list if tgt_shape == (8, 4, 4, 64, 16, 4): - assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_17' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_18' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_19' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_20' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_21' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_22' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_10' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_7' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_6' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_5' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_0' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_1' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_2' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_3' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_4' in strategy_name_list - - -@run_on_environment_flag(name='AUTO_PARALLEL') + assert "[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11" in strategy_name_list + assert "[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12" in strategy_name_list + assert "[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13" in strategy_name_list + assert "[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14" in strategy_name_list + assert "[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15" in strategy_name_list + assert "[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16" in strategy_name_list + assert "[S0, R, R, R] -> [S0, R, R, R, R, R]_17" in strategy_name_list + assert "[R, S0, R, R] -> [R, S0, R, R, R, R]_18" in strategy_name_list + assert "[R, R, S0, R] -> [R, R, R, S0, R, R]_19" in strategy_name_list + assert "[S1, R, R, R] -> [S1, R, R, R, R, R]_20" in strategy_name_list + assert "[R, S1, R, R] -> [R, S1, R, R, R, R]_21" in strategy_name_list + assert "[R, R, S1, R] -> [R, R, R, S1, R, R]_22" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, R, S1, R]_10" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, R, S0, R]_9" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_8" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_7" in strategy_name_list + assert "[R, R, R, S0] -> [R, R, R, R, S0, R]_6" in strategy_name_list + assert "[R, R, R, S1] -> [R, R, R, R, S1, R]_5" in strategy_name_list + assert "[S01, R, R, R] -> [S01, R, R, R, R, R]_0" in strategy_name_list + assert "[R, S01, R, R] -> [R, S01, R, R, R, R]_1" in strategy_name_list + assert "[R, R, S01, R] -> [R, R, R, S01, R, R]_2" in strategy_name_list + assert "[R, R, R, R] -> [R, R, R, R, R, R]_3" in strategy_name_list + assert "[R, R, R, S01] -> [R, R, R, R, S01, R]_4" in strategy_name_list + + +@run_on_environment_flag(name="AUTO_PARALLEL") @pytest.mark.dist @rerun_if_address_is_in_use() -@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) -@parameterize('model_cls', [ConvViewModel, LinearViewModel]) +@parameterize("tgt_shape", [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) +@parameterize("model_cls", [ConvViewModel, LinearViewModel]) def test_view_handler(tgt_shape, model_cls): spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls) -if __name__ == '__main__': +if __name__ == "__main__": test_view_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py index bd7635ac1737..10ca644cddc2 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py @@ -12,7 +12,6 @@ class ConvModel(nn.Module): - def __init__(self): super().__init__() @@ -21,7 +20,7 @@ def forward(self, condition, x, y): return output -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") @clear_cache_before_run() def test_where_handler(): model = ConvModel() @@ -33,9 +32,9 @@ def test_where_handler(): # %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {}) # return where meta_args = { - 'condition': torch.rand(4, 4, 64, 64).to('meta'), - 'x': torch.rand(4, 1, 64, 64).to('meta'), - 'y': torch.rand(1, 4, 64, 64).to('meta') + "condition": torch.rand(4, 4, 64, 64).to("meta"), + "x": torch.rand(4, 1, 64, 64).to("meta"), + "y": torch.rand(1, 4, 64, 64).to("meta"), } graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(model, graph) @@ -59,28 +58,28 @@ def test_where_handler(): assert op_data.logical_shape is not None assert op_data.data is not None - assert mapping['condition'].name == "condition" - assert mapping['condition'].data.is_meta - assert mapping['condition'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['condition'].type == OperationDataType.ARG - assert mapping['condition'].logical_shape == torch.Size([4, 4, 64, 64]) - - assert mapping['x'].name == "x" - assert mapping['x'].data.is_meta - assert mapping['x'].data.shape == torch.Size([4, 1, 64, 64]) - assert mapping['x'].type == OperationDataType.ARG - assert mapping['x'].logical_shape == torch.Size([4, 4, 64, 64]) - - assert mapping['y'].name == "y" - assert mapping['y'].data.is_meta - assert mapping['y'].data.shape == torch.Size([1, 4, 64, 64]) - assert mapping['y'].type == OperationDataType.ARG - assert mapping['y'].logical_shape == torch.Size([4, 4, 64, 64]) - - assert mapping['output'].name == "where" - assert mapping['output'].data.is_meta - assert mapping['output'].data.shape == torch.Size([4, 4, 64, 64]) - assert mapping['output'].type == OperationDataType.OUTPUT + assert mapping["condition"].name == "condition" + assert mapping["condition"].data.is_meta + assert mapping["condition"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["condition"].type == OperationDataType.ARG + assert mapping["condition"].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping["x"].name == "x" + assert mapping["x"].data.is_meta + assert mapping["x"].data.shape == torch.Size([4, 1, 64, 64]) + assert mapping["x"].type == OperationDataType.ARG + assert mapping["x"].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping["y"].name == "y" + assert mapping["y"].data.is_meta + assert mapping["y"].data.shape == torch.Size([1, 4, 64, 64]) + assert mapping["y"].type == OperationDataType.ARG + assert mapping["y"].logical_shape == torch.Size([4, 4, 64, 64]) + + assert mapping["output"].name == "where" + assert mapping["output"].data.is_meta + assert mapping["output"].data.shape == torch.Size([4, 4, 64, 64]) + assert mapping["output"].type == OperationDataType.OUTPUT handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] @@ -88,5 +87,5 @@ def test_where_handler(): assert len(strategy_name_list) == 25 -if __name__ == '__main__': +if __name__ == "__main__": test_where_handler() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py index 28a8bbd9a4c1..3591c663897c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py @@ -2,7 +2,6 @@ from typing import Dict, List import torch -from torch.fx import GraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass @@ -18,16 +17,18 @@ from colossalai.testing.comparison import assert_close -def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor], - input_kwargs: Dict[str, torch.Tensor], grad_dict: Dict[any, torch.Tensor]): - +def _build_model_to_compare( + model: torch.nn.Module, + input_args: List[torch.Tensor], + input_kwargs: Dict[str, torch.Tensor], + grad_dict: Dict[any, torch.Tensor], +): model_to_compare = copy.deepcopy(model) args_to_compare = [] kwargs_to_compare = {} for arg_index, input_tensor in enumerate(input_args): def wrapper(param, index): - def hook_fn(grad): grad_dict[index] = grad @@ -45,7 +46,6 @@ def hook_fn(grad): for name, input_kwarg in input_kwargs.items(): def wrapper(param, name): - def hook_fn(grad): grad_dict[name] = grad @@ -63,30 +63,34 @@ def hook_fn(grad): return model_to_compare, args_to_compare, kwargs_to_compare -def numerical_test_for_node_strategy(model: torch.nn.Module, - device_mesh: DeviceMesh, - node_index: int, - strategy_number: int, - input_args: List[torch.Tensor], - meta_arg_names: List[str], - input_kwargs: Dict[str, torch.Tensor] = {}, - node_type: str = 'normal'): +def numerical_test_for_node_strategy( + model: torch.nn.Module, + device_mesh: DeviceMesh, + node_index: int, + strategy_number: int, + input_args: List[torch.Tensor], + meta_arg_names: List[str], + input_kwargs: Dict[str, torch.Tensor] = {}, + node_type: str = "normal", +): for strategy_index in range(strategy_number): - print(f'#strategy_index: {strategy_index}') + print(f"#strategy_index: {strategy_index}") # We need to copy the model to avoid do backward more than once in same graph grad_to_compare_dict = {} grad_to_shard_dict = {} model_to_compare, args_to_compare, kwargs_to_compare = _build_model_to_compare( - model, input_args, input_kwargs, grad_to_compare_dict) - model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs, - grad_to_shard_dict) + model, input_args, input_kwargs, grad_to_compare_dict + ) + model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare( + model, input_args, input_kwargs, grad_to_shard_dict + ) tracer = ColoTracer(bias_addition_split=True) input_sample = {} for input_arg, meta_arg_name in zip(input_args, meta_arg_names): - input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to('meta') + input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to("meta") for meta_kwarg_name, input_kwarg in input_kwargs.items(): - input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to('meta') + input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to("meta") graph = tracer.trace(root=model_to_shard, meta_args=input_sample) gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) shape_prop_pass(gm, *input_sample.values()) @@ -94,13 +98,14 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() - target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies - ][node_index] - if node_type == 'normal': + target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies][ + node_index + ] + if node_type == "normal": solution_len = len(strategies_constructor.leaf_strategies) solution = [0] * solution_len solution[node_index] = strategy_index - elif node_type == 'following': + elif node_type == "following": solution_len = len(strategies_constructor.leaf_strategies) solution = [0] * solution_len solution[node_index] = strategy_index @@ -116,18 +121,21 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, ret = solver.call_solver_serialized_args() solution = list(ret[0]) gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass( - gm, solution, device_mesh, strategies_constructor) + gm, solution, device_mesh, strategies_constructor + ) gm = runtime_apply_pass(gm) gm.recompile() # forward result compare - output = gm(*args_to_shard, - sharding_spec_convert_dict=sharding_spec_dict, - origin_node_sharding_spec_dict=origin_spec_dict, - comm_actions_dict=comm_actions_dict, - **kwargs_to_shard) + output = gm( + *args_to_shard, + sharding_spec_convert_dict=sharding_spec_dict, + origin_node_sharding_spec_dict=origin_spec_dict, + comm_actions_dict=comm_actions_dict, + **kwargs_to_shard, + ) output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare) - assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type='forward output') + assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type="forward output") # backward result compare if isinstance(output, (tuple, list)): @@ -142,43 +150,45 @@ def numerical_test_for_node_strategy(model: torch.nn.Module, for key in grad_to_shard_dict.keys(): grad_to_shard = grad_to_shard_dict[key] grad_to_compare = grad_to_compare_dict[key] - assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type='input grad') + assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type="input grad") # extract the strategy used in this iter strategy_in_use = target_node.strategies_vector[strategy_index] param_to_shard_dict = dict(gm.named_parameters()) param_to_compare_dict = dict(model_to_compare.named_parameters()) for name in param_to_shard_dict.keys(): - param_name = name.split('.')[-1] - if node_type == 'normal': + param_name = name.split(".")[-1] + if node_type == "normal": param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name) else: - if 'weight' in name: + if "weight" in name: param_sharding_spec = None for node in list(graph.nodes): - if 'weight' in node.name: + if "weight" in node.name: param_sharding_spec = node.sharding_spec - elif 'bias' in name: + elif "bias" in name: param_sharding_spec = None for node in list(graph.nodes): - if 'bias' in node.name: + if "bias" in node.name: param_sharding_spec = node.sharding_spec assert param_sharding_spec is not None grad_sharded = param_to_shard_dict[name].grad grad_to_compare = param_to_compare_dict[name].grad global_grad = to_global(grad_sharded, param_sharding_spec) - assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type='param grad') + assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type="param grad") -def assert_close_helper(first: torch.Tensor, - second: torch.Tensor, - rtol: float = 1e-2, - atol: float = 1e-2, - strategy_index: int = -1, - type: str = 'not defined'): +def assert_close_helper( + first: torch.Tensor, + second: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2, + strategy_index: int = -1, + type: str = "not defined", +): """ This method is used to check whether the average difference between two tensors is as close as expected. """ @@ -189,4 +199,4 @@ def assert_close_helper(first: torch.Tensor, else: assert_close(first, second, rtol=rtol, atol=atol) except: - print(f'strategy index {strategy_index} encounter assert_close error on {type}') + print(f"strategy index {strategy_index} encounter assert_close error on {type}") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index 0d93e4e40527..e7b8c696e62e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -3,17 +3,18 @@ from torchvision.models import resnet50 from colossalai._analyzer.fx.passes import shape_prop_pass + # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing import clear_cache_before_run, run_on_environment_flag -@run_on_environment_flag(name='AUTO_PARALLEL') +@run_on_environment_flag(name="AUTO_PARALLEL") @clear_cache_before_run() def test_cost_graph(): physical_mesh_id = torch.arange(0, 8) @@ -21,11 +22,11 @@ def test_cost_graph(): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - shape_consistency_manager = ShapeConsistencyManager() + ShapeConsistencyManager() tracer = ColoTracer(bias_addition_split=True) model = resnet50(num_classes=100000) - input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')} + input_sample = {"x": torch.rand(128, 3, 224, 224).to("meta")} graph = tracer.trace(root=model, meta_args=input_sample) # graph(): @@ -74,7 +75,7 @@ def test_cost_graph(): communication_cost_bn = 0 memory_cost = 0 for index, node in enumerate(graph.nodes): - if node.op == 'call_module': + if node.op == "call_module": submod = node.graph.owning_module.get_submodule(node.target) if type(submod) in BATCHNORM_MODULE_OP: communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost.total @@ -86,11 +87,11 @@ def test_cost_graph(): node_memory_cost = node_memory_cost[0] memory_cost += node_memory_cost.activation + node_memory_cost.parameter - print(f'computation cost is {computation_cost}') - print(f'communication cost is {communication_cost}') - print(f'memory cost is {memory_cost}') - print(f'bn communication cost is {communication_cost_bn}') + print(f"computation cost is {computation_cost}") + print(f"communication cost is {communication_cost}") + print(f"memory cost is {memory_cost}") + print(f"bn communication cost is {communication_cost_bn}") -if __name__ == '__main__': +if __name__ == "__main__": test_cost_graph() diff --git a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py index d07145e48e1f..07fd0ad582e9 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py +++ b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List +from typing import Any import torch import torch.fx @@ -111,13 +111,14 @@ def _benchmark_speed(model, inputs, loop=5): def benchmark_evoformer_stack(data_args): from test_autochunk_evoformer_stack import get_data, get_model + print("\nmsa len: %d, pair len: %d" % (data_args[0], data_args[1])) max_mem = _benchmark_evoformer_stack_origin(data_args, get_model, get_data) for ratio in [0.5, 0.4, 0.3, 0.2, 0.1]: try: _benchmark_evoformer_stack_gm(data_args, max_mem * ratio, get_model, get_data) except RuntimeError as e: - if e.args[0] == 'Search failed. Try a larger memory threshold.': + if e.args[0] == "Search failed. Try a larger memory threshold.": break except Exception as e: raise e diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py index 593658fd1368..3d3f212a68d0 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py @@ -8,7 +8,6 @@ from colossalai.autochunk.utils import flat_list from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.legacy.core import global_context as gpc from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: @@ -80,9 +79,9 @@ def assert_codegen_run( out_gm = flat_list(out_gm) out_model = flat_list(out_model) for out_gm_i, out_model_i in zip(out_gm, out_model): - assert torch.allclose(out_gm_i, out_model_i, - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_gm_i - out_model_i)) + assert torch.allclose( + out_gm_i, out_model_i, atol=1e-4 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_gm_i - out_model_i)) return chunks diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py index 9e4cb7ee9f95..1a4ababda30d 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py @@ -6,6 +6,7 @@ try: from fastfold.model.nn.evoformer import EvoformerBlock + HAS_REPO = True except: HAS_REPO = False @@ -17,22 +18,26 @@ def get_model(): - model = EvoformerBlock( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - no_heads_msa=8, - no_heads_pair=4, - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.15, - inf=1e4, - eps=1e-4, - is_multimer=False, - ).eval().cuda() + model = ( + EvoformerBlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + is_multimer=False, + ) + .eval() + .cuda() + ) return model @@ -54,8 +59,20 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: def get_chunk_target() -> Dict: return { - None: [(120, 126), (225, 244), (270, 289), (306, 311), (70, 106), (23, 46), (146, 152), (187, 193), (181, 184), - (140, 145), (162, 163), (203, 204)], + None: [ + (120, 126), + (225, 244), + (270, 289), + (306, 311), + (70, 106), + (23, 46), + (146, 152), + (187, 193), + (181, 184), + (140, 145), + (162, 163), + (203, 204), + ], 20: [(120, 123), (232, 237), (277, 282), (305, 306)], 24: [(122, 123)], } diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py index 6b47033e199f..0b04ba5257b6 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py @@ -6,6 +6,7 @@ try: from fastfold.model.nn.evoformer import EvoformerStack + HAS_REPO = True except: HAS_REPO = False @@ -17,26 +18,30 @@ def get_model(): - model = EvoformerStack( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - c_s=384, - no_heads_msa=8, - no_heads_pair=4, - no_blocks=2, # 48 - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.25, - blocks_per_ckpt=None, - inf=1000000000.0, - eps=1e-08, - clear_cache_between_blocks=False, - is_multimer=False, - ).eval().cuda() + model = ( + EvoformerStack( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + c_s=384, + no_heads_msa=8, + no_heads_pair=4, + no_blocks=2, # 48 + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.25, + blocks_per_ckpt=None, + inf=1000000000.0, + eps=1e-08, + clear_cache_between_blocks=False, + is_multimer=False, + ) + .eval() + .cuda() + ) return model @@ -62,7 +67,7 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: ) @clear_cache_before_run() @parameterize("max_memory", [None, 20, 24]) -@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) def test_evoformer_stack(data_args, max_memory): spawn( run_test, diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py index b4c577c18ee6..585a9e3381c4 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import List, Tuple import pytest import torch @@ -6,6 +6,7 @@ try: from fastfold.model.nn.evoformer import ExtraMSABlock + HAS_REPO = True except: HAS_REPO = False @@ -16,23 +17,27 @@ def get_model(): - model = ExtraMSABlock( - c_m=256, - c_z=128, - c_hidden_msa_att=32, - c_hidden_opm=32, - c_hidden_mul=128, - c_hidden_pair_att=32, - no_heads_msa=8, - no_heads_pair=4, - transition_n=4, - msa_dropout=0.15, - pair_dropout=0.15, - inf=1e4, - eps=1e-4, - ckpt=False, - is_multimer=False, - ).eval().cuda() + model = ( + ExtraMSABlock( + c_m=256, + c_z=128, + c_hidden_msa_att=32, + c_hidden_opm=32, + c_hidden_mul=128, + c_hidden_pair_att=32, + no_heads_msa=8, + no_heads_pair=4, + transition_n=4, + msa_dropout=0.15, + pair_dropout=0.15, + inf=1e4, + eps=1e-4, + ckpt=False, + is_multimer=False, + ) + .eval() + .cuda() + ) return model @@ -58,7 +63,7 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: ) @clear_cache_before_run() @parameterize("max_memory", [None, 20, 24]) -@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) def test_extramsa_block(data_args, max_memory): spawn( run_test, diff --git a/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py index 6fb7efa7a8fc..b75cbe67590c 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py +++ b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List +from typing import Any import torch import torch.fx @@ -64,8 +64,10 @@ def _benchmark_autochunk_unet_gm( para_mem = float(parameter_size(model)) / 1024**2 act_mem = _benchmark_memory(gm, inputs) speed = _benchmark_speed(gm, inputs) - print("unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) def _benchmark_autochunk_unet_origin( @@ -86,8 +88,10 @@ def _benchmark_autochunk_unet_origin( para_mem = float(parameter_size(model)) / 1024**2 act_mem = _benchmark_memory(model, inputs) speed = _benchmark_speed(model, inputs) - print("unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) return act_mem @@ -115,6 +119,7 @@ def _benchmark_speed(model, inputs, loop=5): def benchmark_autochunk_unet(batch=1, height=448, width=448): from test_autochunk_unet import UNet2DModel, get_data + model = UNet2DModel() latent_shape = (batch, 3, height // 7, width // 7) @@ -124,7 +129,7 @@ def benchmark_autochunk_unet(batch=1, height=448, width=448): try: _benchmark_autochunk_unet_gm(model, get_data(latent_shape), max_mem * ratio) except RuntimeError as e: - if e.args[0] == 'Search failed. Try a larger memory threshold.': + if e.args[0] == "Search failed. Try a larger memory threshold.": break except Exception as e: raise e diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py index 264331a5fef0..32034992090f 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py @@ -83,9 +83,11 @@ def assert_codegen_run( max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2 print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm)) - assert torch.allclose(out_gm["sample"], out_model["sample"], - atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_gm["sample"] - out_model["sample"])) + assert torch.allclose( + out_gm["sample"], out_model["sample"], atol=1e-3 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( + torch.abs(out_gm["sample"] - out_model["sample"]) + ) return chunks @@ -129,7 +131,7 @@ def run_test( if get_chunk_target is not None: chunk_found = [i["region"] for i in chunks] chunk_target = get_chunk_target()[max_memory] - assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( str(chunk_found), str(chunk_target), ) diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py index f0cf2a5fcbca..ad50874c92a3 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -5,9 +5,11 @@ try: import diffusers + MODELS = [diffusers.UNet2DModel] HAS_REPO = True from packaging import version + SKIP_UNET_TEST = version.parse(diffusers.__version__) > version.parse("0.10.2") except: MODELS = [] diff --git a/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py index 63490aaee7ff..e70e50175032 100644 --- a/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py +++ b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List +from typing import Any import torch import torch.fx @@ -64,8 +64,10 @@ def _benchmark_autochunk_gpt_gm( para_mem = float(parameter_size(model)) / 1024**2 * 6 act_mem = _benchmark_memory(gm, inputs) speed = _benchmark_speed(gm, inputs) - print("gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) def _benchmark_autochunk_gpt_origin( @@ -86,8 +88,10 @@ def _benchmark_autochunk_gpt_origin( para_mem = float(parameter_size(model)) / 1024**2 * 6 act_mem = _benchmark_memory(model, inputs) speed = _benchmark_speed(model, inputs) - print("gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" % - (speed, act_mem, para_mem, act_mem + para_mem)) + print( + "gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" + % (speed, act_mem, para_mem, act_mem + para_mem) + ) return act_mem @@ -115,6 +119,7 @@ def _benchmark_speed(model, inputs, loop=5): def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): from test_autochunk_gpt import GPT2Config, GPT2Model, get_data + model = GPT2Model config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head) model = model(config=config) @@ -125,7 +130,7 @@ def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): try: _benchmark_autochunk_gpt_gm(model, get_data(shape), max_mem * ratio) except RuntimeError as e: - if e.args[0] == 'Search failed. Try a larger memory threshold.': + if e.args[0] == "Search failed. Try a larger memory threshold.": break except Exception as e: raise e diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py index 82af6c05c6ef..b2d842ee6a7b 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py @@ -5,6 +5,7 @@ try: from transformers import GPT2Config, GPT2Model + MODELS = [GPT2Model] HAS_REPO = True except: @@ -52,13 +53,15 @@ def test_autochunk_gpt(model, shape, max_memory): if __name__ == "__main__": - run_test(rank=0, - data=get_data((BATCH_SIZE, SEQ_LENGTH)), - max_memory=None, - model=GPT2Model, - config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), - print_code=False, - print_est_mem=False, - print_mem=False, - print_progress=False, - eval_mem=False) + run_test( + rank=0, + data=get_data((BATCH_SIZE, SEQ_LENGTH)), + max_memory=None, + model=GPT2Model, + config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), + print_code=False, + print_est_mem=False, + print_mem=False, + print_progress=False, + eval_mem=False, + ) diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py index 5c863b0df47f..77c11db71a5c 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py @@ -38,11 +38,9 @@ def assert_codegen_run( meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors] interp.propagate(*meta_tensors) - codegen = AutoChunkCodeGen(meta_graph, - max_memory=max_memory, - print_mem=print_est_mem, - print_progress=print_progress, - eval_mem=eval_mem) + codegen = AutoChunkCodeGen( + meta_graph, max_memory=max_memory, print_mem=print_est_mem, print_progress=print_progress, eval_mem=eval_mem + ) chunks = codegen.chunk_infos # trace and recompile @@ -85,9 +83,9 @@ def assert_allclose(out_model: Any, out_gm: Any) -> None: assert allclose for out """ if isinstance(out_model, torch.Tensor): - assert torch.allclose(out_model, out_gm, - atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_model - out_gm)) + assert torch.allclose( + out_model, out_gm, atol=1e-4 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_model - out_gm)) elif isinstance(out_model, dict): for k in out_model.keys(): assert_allclose(out_model[k], out_gm[k]) @@ -123,19 +121,21 @@ def run_test( ) # build model and input - chunks = assert_codegen_run(model, - data=data, - max_memory=max_memory, - print_code=print_code, - print_est_mem=print_est_mem, - print_mem=print_mem, - print_progress=print_progress, - eval_mem=eval_mem) + chunks = assert_codegen_run( + model, + data=data, + max_memory=max_memory, + print_code=print_code, + print_est_mem=print_est_mem, + print_mem=print_mem, + print_progress=print_progress, + eval_mem=eval_mem, + ) if get_chunk_target is not None: chunk_found = [i["region"] for i in chunks] chunk_target = get_chunk_target()[max_memory] - assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( str(chunk_found), str(chunk_target), ) diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py index a98aa0e03954..aa868d683f06 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py @@ -5,6 +5,7 @@ try: from timm.models.vision_transformer import vit_large_patch16_384 as vit + MODELS = [vit] HAS_REPO = True except: @@ -19,7 +20,7 @@ def get_data() -> Tuple[List, List]: data = torch.rand(1, 3, 384, 384) - meta_args = {'x': data} + meta_args = {"x": data} return data, meta_args diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py index 65d1e9c4d090..ca919fb7e4fe 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py @@ -75,9 +75,9 @@ def assert_codegen_run( max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2 print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm)) - assert torch.allclose(out_gm, out_model, - atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean( - torch.abs(out_gm - out_model)) + assert torch.allclose( + out_gm, out_model, atol=1e-3 + ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(torch.abs(out_gm - out_model)) return chunks @@ -121,7 +121,7 @@ def run_test( if get_chunk_target is not None: chunk_found = [i["region"] for i in chunks] chunk_target = get_chunk_target()[max_memory] - assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % ( + assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % ( str(chunk_found), str(chunk_target), ) diff --git a/tests/test_booster/test_accelerator.py b/tests/test_booster/test_accelerator.py index 6f3f66ed41b8..777589299d13 100644 --- a/tests/test_booster/test_accelerator.py +++ b/tests/test_booster/test_accelerator.py @@ -5,7 +5,7 @@ @clear_cache_before_run() -@parameterize('device', ['cpu', 'cuda']) +@parameterize("device", ["cpu", "cuda"]) def test_accelerator(device): accelerator = Accelerator(device) model = nn.Linear(8, 8) diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py index 26ce00e94869..3aefb37974f0 100644 --- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -9,11 +9,11 @@ def run_torch_amp(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - sub_model_zoo = model_zoo.get_sub_registry('timm') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") + sub_model_zoo = model_zoo.get_sub_registry("timm") for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items(): # dlrm_interactionarch has not parameters, so skip - if name == 'dlrm_interactionarch': + if name == "dlrm_interactionarch": continue model = model_fn().cuda() @@ -21,7 +21,7 @@ def run_torch_amp(rank, world_size, port): criterion = lambda x: x.mean() data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } mixed_precision = FP16TorchMixedPrecision() model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion) diff --git a/tests/test_booster/test_plugin/test_3d_plugin.py b/tests/test_booster/test_plugin/test_3d_plugin.py index a58afac810d7..ad878fb0c86a 100644 --- a/tests/test_booster/test_plugin/test_3d_plugin.py +++ b/tests/test_booster/test_plugin/test_3d_plugin.py @@ -16,11 +16,11 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: try: - if init_method == 'lazy': + if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() - plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision='bf16') + plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4, precision="bf16") booster = Booster(plugin=plugin) with ctx: model = model_fn() @@ -29,7 +29,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ data = data_gen_fn() data = { - k: v.to('cuda').repeat(4, 1) if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v + k: v.to("cuda").repeat(4, 1) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } @@ -50,23 +50,24 @@ def _criterion(outputs, inputs): return repr(e) -@parameterize('init_method', ['none', 'lazy']) -def check_3d_plugin(init_method: str = 'none', early_stop: bool = True): +@parameterize("init_method", ["none", "lazy"]) +def check_3d_plugin(init_method: str = "none", early_stop: bool = True): """check gemini plugin over model zoo Args: early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. """ is_support_meta = is_compatible_with_meta() - if not is_support_meta and init_method == 'lazy': + if not is_support_meta and init_method == "lazy": return passed_models = [] - failed_info = {} # (model_name, error) pair + failed_info = {} # (model_name, error) pair # TODO(ver217): add more models - for name, (model_fn, data_gen_fn, output_transform_fn, _, - _) in model_zoo.get_sub_registry('transformers_llama_for_casual_lm').items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry( + "transformers_llama_for_casual_lm" + ).items(): err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() @@ -78,15 +79,15 @@ def check_3d_plugin(init_method: str = 'none', early_stop: bool = True): break if dist.get_rank() == 0: - print(f'Init method: {init_method}') - print(f'Passed models({len(passed_models)}): {passed_models}\n\n') - print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') - assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + print(f"Init method: {init_method}") + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_3d_plugin(early_stop=early_stop) @@ -95,5 +96,5 @@ def test_gemini_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) -if __name__ == '__main__': +if __name__ == "__main__": test_gemini_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py index 689b334cae50..0ac9d0f6d409 100644 --- a/tests/test_booster/test_plugin/test_dp_plugin_base.py +++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py @@ -15,8 +15,7 @@ class DPPluginWrapper(DPPluginBase): - """This is a wrapper class for testing DP plugin initialization and dataloader creation. - """ + """This is a wrapper class for testing DP plugin initialization and dataloader creation.""" def configure( self, @@ -73,13 +72,14 @@ def check_dataloader_sharding(): # compare on rank 0 if is_rank_0: - assert not torch.equal(batch, - batch_to_compare), 'Same number was found across ranks but expected it to be different' + assert not torch.equal( + batch, batch_to_compare + ), "Same number was found across ranks but expected it to be different" def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_dataloader_sharding() diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 18be68bf6e48..00ff6cb37d2a 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -17,7 +17,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: try: - if init_method == 'lazy': + if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() @@ -30,13 +30,13 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) for n, p in model.named_parameters(): - assert isinstance(p, ColoParameter), f'{n} is not a ColoParameter' + assert isinstance(p, ColoParameter), f"{n} is not a ColoParameter" output = model(**data) output = output_transform_fn(output) @@ -55,47 +55,65 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ # @parameterize('init_method', ['lazy', 'none', 'colo']) -@parameterize('subset', ['torchvision', 'transformers', 'diffusers']) -@parameterize('init_method', ['none']) -def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool = True): +@parameterize("subset", ["torchvision", "transformers", "diffusers"]) +@parameterize("init_method", ["none"]) +def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True): """check gemini plugin over model zoo Args: early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. """ is_support_meta = is_compatible_with_meta() - if not is_support_meta and init_method == 'lazy': + if not is_support_meta and init_method == "lazy": return passed_models = [] - failed_info = {} # (model_name, error) pair + failed_info = {} # (model_name, error) pair for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(subset).items(): # These models lead to CUDA error - if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp', - 'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext', - 'torchvision_convnext_base'): + if name in ( + "diffusers_auto_encoder_kl", + "diffusers_vq_model", + "diffusers_unet2d_model", + "timm_resmlp", + "timm_gmixer_12_224", + "timm_gmlp_b16_224", + "timm_mixer_b16_224", + "timm_convnext", + "torchvision_convnext_base", + ): continue # These models are not compatible with gemini if name in [ - 'timm_convit', - 'timm_dm_nfnet', - 'torchvision_vit_b_16', - 'transformers_t5', - 'transformers_t5_for_conditional_generation', - 'transformers_t5_encoder_model', # does not support apex rmsnorm - 'transformers_chatglm', - 'transformers_sam', - 'transformers_vit', - 'transformers_gpt_double_heads', # TODO check why does the model fail to run using Gemini + "timm_convit", + "timm_dm_nfnet", + "torchvision_vit_b_16", + "transformers_t5", + "transformers_t5_for_conditional_generation", + "transformers_t5_encoder_model", # does not support apex rmsnorm + "transformers_chatglm", + "transformers_sam", + "transformers_vit", + "transformers_gpt_double_heads", # TODO check why does the model fail to run using Gemini ]: continue - if init_method == 'lazy' and name in [ - 'timm_convmixer', 'timm_vision_transformer', 'timm_deit', 'timm_deit3', 'timm_inception_v3', - 'timm_tnt_b_patch16_224', 'timm_rexnet', 'torchvision_densenet121', 'torchvision_efficientnet_b0', - 'torchvision_mobilenet_v2', 'torchvision_mnasnet0_5', 'torchvision_regnet_x_16gf', - 'torchvision_shufflenet_v2_x0_5', 'torchvision_efficientnet_v2_s' + if init_method == "lazy" and name in [ + "timm_convmixer", + "timm_vision_transformer", + "timm_deit", + "timm_deit3", + "timm_inception_v3", + "timm_tnt_b_patch16_224", + "timm_rexnet", + "torchvision_densenet121", + "torchvision_efficientnet_b0", + "torchvision_mobilenet_v2", + "torchvision_mnasnet0_5", + "torchvision_regnet_x_16gf", + "torchvision_shufflenet_v2_x0_5", + "torchvision_efficientnet_v2_s", ]: continue err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) @@ -108,15 +126,15 @@ def check_gemini_plugin(subset: str, init_method: str = 'none', early_stop: bool break if dist.get_rank() == 0: - print(f'Init method: {init_method}') - print(f'Passed models({len(passed_models)}): {passed_models}\n\n') - print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') - assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + print(f"Init method: {init_method}") + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_gemini_plugin(early_stop=early_stop) @@ -125,5 +143,5 @@ def test_gemini_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) -if __name__ == '__main__': +if __name__ == "__main__": test_gemini_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 79f98a4c95d0..9cc12f96bd4d 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -11,9 +11,9 @@ from tests.kit.model_zoo import model_zoo # These models are not compatible with AMP -_AMP_ERR_MODELS = ['timm_convit', 'deepfm_interactionarch'] +_AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"] # These models have no parameters -_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch'] +_LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"] def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: @@ -26,7 +26,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) @@ -43,7 +43,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: return repr(e) -@parameterize('stage', [2]) +@parameterize("stage", [2]) def check_low_level_zero_plugin(stage: int, early_stop: bool = True): """check low level zero plugin over model zoo @@ -52,7 +52,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. """ passed_models = [] - failed_info = {} # (model_name, error) pair + failed_info = {} # (model_name, error) pair ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS skipped_models = [] @@ -73,15 +73,15 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): break if dist.get_rank() == 0: - print(f'Passed models({len(passed_models)}): {passed_models}\n\n') - print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') - print(f'Skipped models({len(skipped_models)}): {skipped_models}\n\n') - assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) + print(f"Passed models({len(passed_models)}): {passed_models}\n\n") + print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n") + print(f"Skipped models({len(skipped_models)}): {skipped_models}\n\n") + assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()]) def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_plugin(early_stop=early_stop) @@ -90,5 +90,5 @@ def test_low_level_zero_plugin(early_stop: bool = True): spawn(run_dist, 4, early_stop=early_stop) -if __name__ == '__main__': +if __name__ == "__main__": test_low_level_zero_plugin(early_stop=False) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index 23d743c924aa..1a7ca6f2a30c 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -22,7 +22,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): criterion = lambda x: x.mean() data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) @@ -41,14 +41,13 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_ddp_plugin(): for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): - if name == 'dlrm_interactionarch': + if name == "dlrm_interactionarch": continue run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() class DummyModel(nn.Module): - def __init__(self): super().__init__() self.weight = nn.Parameter(torch.rand(1)) @@ -67,10 +66,9 @@ def check_torch_ddp_no_sync(): # create a custom dataset with 0 to 10 dataset = torch.arange(0, 10) train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2) - model, optimizer, criterion, train_dataloader, _ = booster.boost(model, - optimizer, - criterion, - dataloader=train_dataloader) + model, optimizer, criterion, train_dataloader, _ = booster.boost( + model, optimizer, criterion, dataloader=train_dataloader + ) def fwd_bwd(): output = model(batch.cuda()) @@ -105,7 +103,7 @@ def get_grad_set_over_all_ranks(): def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_torch_ddp_plugin() check_torch_ddp_no_sync() diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index e09ad766bb32..8bcbffdd06fe 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -6,7 +6,7 @@ import colossalai from colossalai.booster import Booster -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse("1.12.0"): from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from colossalai.booster.plugin import TorchFSDPPlugin @@ -24,7 +24,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): criterion = lambda x: x.mean() data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) @@ -43,10 +43,16 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_fsdp_plugin(): for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): - if any(element in name for element in [ - 'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet', - 'torchvision_inception_v3' - ]): + if any( + element in name + for element in [ + "diffusers", + "deepfm_sparsearch", + "dlrm_interactionarch", + "torchvision_googlenet", + "torchvision_inception_v3", + ] + ): continue run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() @@ -54,11 +60,11 @@ def check_torch_fsdp_plugin(): def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_torch_fsdp_plugin() -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher") +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="requires torch1.12 or higher") @rerun_if_address_is_in_use() def test_torch_fsdp_plugin(): spawn(run_dist, 2) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 6720be58490b..d66dec113017 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -19,50 +19,30 @@ from tests.kit.model_zoo import model_zoo MODEL_PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0 - }, # zero3 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.5 - }, # zero3-half + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half ] OPTIM_PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 1.0 - }, # zero2-offload - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.5 - }, # zero2-offload-half + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0}, # zero2-offload + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5}, # zero2-offload-half ] @clear_cache_before_run() -@parameterize('placement_config', MODEL_PLACEMENT_CONFIGS) -@parameterize('model_name', ['transformers_bert_for_sequence_classification']) -@parameterize('use_safetensors', [False, True]) +@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS) +@parameterize("model_name", ["transformers_bert_for_sequence_classification"]) +@parameterize("use_safetensors", [False, True]) def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool): from transformers import BertForSequenceClassification + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() with shared_tempdir() as tempdir: - pretrained_path = os.path.join(tempdir, 'pretrained') + pretrained_path = os.path.join(tempdir, "pretrained") bert_model.config.save_pretrained(save_directory=pretrained_path) plugin = GeminiPlugin(**placement_config) @@ -70,24 +50,22 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 - booster.save_model(bert_model, - pretrained_path, - True, - True, - '', (model_size / 3), - use_safetensors=use_safetensors) + booster.save_model( + bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors + ) dist.barrier() new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) - check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32), - new_bert_model.state_dict(), False) + check_state_dict_equal( + bert_model.state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False + ) @clear_cache_before_run() -@parameterize('placement_config', OPTIM_PLACEMENT_CONFIGS) -@parameterize('shard', [False, True]) -@parameterize('model_name', ['transformers_gpt']) -@parameterize('size_per_shard', [32]) +@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS) +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_gpt"]) +@parameterize("size_per_shard", [32]) def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() @@ -102,7 +80,7 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} output = model(**data) output = output_transform_fn(output) output_key = list(output.keys())[0] @@ -123,13 +101,14 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), - False) + check_state_dict_equal( + optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False + ) # Check the new model/optimizer can successfully run. data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } output = new_model(**data) output = output_transform_fn(output) @@ -143,13 +122,13 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() exam_state_dict_with_origin() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py index 4569ea12d82d..d46e5380d944 100644 --- a/tests/test_checkpoint_io/test_gemini_torch_compability.py +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -19,10 +19,9 @@ @clear_cache_before_run() -@parameterize('shard', [False, True]) -@parameterize('model_name', ['transformers_gpt']) +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_gpt"]) def exam_torch_load_from_gemini(shard: bool, model_name: str): - (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14)) @@ -33,7 +32,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} output = model(**data) output = output_transform_fn(output) output_key = list(output.keys())[0] @@ -60,8 +59,11 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): new_booster.load_model(new_model, model_ckpt_path, strict=True) # Add prefix to get aligned with pytorch parameter names. - check_state_dict_equal(model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), - new_model.state_dict(), False) + check_state_dict_equal( + model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32), + new_model.state_dict(), + False, + ) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False) @@ -69,7 +71,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): # Check the new model/optimizer can successfully run. data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } output = new_model(**data) output = output_transform_fn(output) @@ -82,10 +84,9 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): @clear_cache_before_run() -@parameterize('shard', [False, True]) -@parameterize('model_name', ['transformers_gpt']) +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_gpt"]) def exam_gemini_load_from_torch(shard: bool, model_name: str): - (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() plugin = TorchDDPPlugin() @@ -96,7 +97,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} output = model(**data) output = output_transform_fn(output) output_key = list(output.keys())[0] @@ -123,8 +124,11 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): new_booster.load_model(new_model, model_ckpt_path, strict=True) # Add prefix to get aligned with pytorch parameter names. - check_state_dict_equal(new_model.state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), - model.state_dict(), False) + check_state_dict_equal( + new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32), + model.state_dict(), + False, + ) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) old_state_dict = optimizer.state_dict() @@ -132,18 +136,19 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): # Comparison of param_groups needs special care here, # since not all hyperparameters in Adam are used by HybridAdam - hyperparameters_to_examine = ['params', 'lr', 'betas', 'eps', 'weight_decay'] - for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']): + hyperparameters_to_examine = ["params", "lr", "betas", "eps", "weight_decay"] + for old_group, new_group in zip(old_state_dict["param_groups"], new_state_dict["param_groups"]): for k in hyperparameters_to_examine: - assert k in old_group and k in new_group, \ - f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}" + assert ( + k in old_group and k in new_group + ), f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}" assert old_group[k] == new_group[k] - check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False) + check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], False) # Check the new model/optimizer can successfully run. data = data_gen_fn() data = { - k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items() + k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } output = new_model(**data) output = output_transform_fn(output) @@ -157,13 +162,13 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_torch_load_from_gemini() exam_gemini_load_from_torch() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 0976d4503a61..2a046a298dd7 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -5,7 +5,6 @@ from torch.optim import Adam from torchvision.models import resnet18 -from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.testing import check_state_dict_equal, clear_cache_before_run, parameterize @@ -18,7 +17,7 @@ @clear_cache_before_run() -@parameterize('use_safetensors', [True, False]) +@parameterize("use_safetensors", [True, False]) def test_unsharded_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() @@ -59,7 +58,7 @@ def test_unsharded_checkpoint(use_safetensors: bool): check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) -@pytest.mark.parametrize('use_safetensors', [True, False]) +@pytest.mark.parametrize("use_safetensors", [True, False]) def test_sharded_model_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() @@ -75,11 +74,9 @@ def test_sharded_model_checkpoint(use_safetensors: bool): # create a temp file for checkpoint if use_safetensors: - suffix = ".safetensors" - SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" + pass else: - suffix = ".bin" - WEIGHTS_INDEX_NAME = "model.bin.index.json" + pass model_ckpt_dir = tempfile.TemporaryDirectory() optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() @@ -103,7 +100,6 @@ def test_sharded_model_checkpoint(use_safetensors: bool): def test_sharded_optimizer_checkpoint(): - # create a model and optimizer model = resnet18() optimizer = Adam(model.parameters(), lr=0.001) @@ -162,16 +158,11 @@ def test_sharded_optimizer_checkpoint(): def test_sharded_optimizer_multiple_param_groups(): - # create a model and optimizer model = resnet18() - optimizer = Adam([{ - 'params': model.layer1.parameters() - }, { - 'params': model.layer2.parameters(), - 'lr': 0.002 - }], - lr=0.001) + optimizer = Adam( + [{"params": model.layer1.parameters()}, {"params": model.layer2.parameters(), "lr": 0.002}], lr=0.001 + ) # create test data sample x = torch.randn(1, 3, 224, 224) @@ -194,13 +185,9 @@ def test_sharded_optimizer_multiple_param_groups(): # create new model new_model = resnet18() - new_optimizer = Adam([{ - 'params': new_model.layer1.parameters() - }, { - 'params': new_model.layer2.parameters(), - 'lr': 0.002 - }], - lr=0.001) + new_optimizer = Adam( + [{"params": new_model.layer1.parameters()}, {"params": new_model.layer2.parameters(), "lr": 0.002}], lr=0.001 + ) ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) ckpt_io.load_optimizer(new_optimizer, str(optimizer_ckpt_dir.name)) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index e43908e0c651..e8bb8f9e3475 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -22,37 +22,26 @@ # TODO (Baizhou): Add test cases for shard=False @clear_cache_before_run() -@parameterize('shard', [True]) -@parameterize('model_name', ['transformers_gpt']) -@parameterize('size_per_shard', [32]) -@parameterize('test_config', [{ - 'tp_size': 4, - 'pp_size': 1, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 2, - 'pp_size': 1, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize("shard", [True]) +@parameterize("model_name", ["transformers_gpt"]) +@parameterize("size_per_shard", [32]) +@parameterize( + "test_config", + [ + { + "tp_size": 4, + "pp_size": 1, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 2, "num_microbatches": 4, "precision": "fp16", "initial_scale": 1}, + {"tp_size": 2, "pp_size": 1, "zero_stage": 2, "precision": "fp16", "initial_scale": 1}, + {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1}, + ], +) def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): - - (model_fn, data_gen_fn, output_transform_fn, loss_fn, - _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) criterion = loss_fn plugin = HybridParallelPlugin(**test_config) booster = Booster(plugin=plugin) @@ -65,10 +54,10 @@ def _criterion(outputs, inputs): def _preprocess_data(data): if booster.plugin.stage_manager is not None: for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = 4 - data[k] = v.to('cuda').repeat(*new_shape) + data[k] = v.to("cuda").repeat(*new_shape) return iter([data]) else: return {k: v.cuda() for k, v in data.items()} @@ -80,12 +69,9 @@ def _preprocess_data(data): data = data_gen_fn() model.train() if booster.plugin.stage_manager is not None: - booster.execute_pipeline(_preprocess_data(data), - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=False) + booster.execute_pipeline( + _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False + ) else: output = model(**_preprocess_data(data)) loss = criterion(output) @@ -94,7 +80,6 @@ def _preprocess_data(data): optimizer.step() with shared_tempdir() as tempdir: - model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) @@ -115,18 +100,12 @@ def _preprocess_data(data): model.train() new_model.train() if booster.plugin.stage_manager is not None: - booster.execute_pipeline(_preprocess_data(data), - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=False) - booster.execute_pipeline(_preprocess_data(data), - new_model, - _criterion, - new_optimizer, - return_loss=True, - return_outputs=False) + booster.execute_pipeline( + _preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False + ) + booster.execute_pipeline( + _preprocess_data(data), new_model, _criterion, new_optimizer, return_loss=True, return_outputs=False + ) else: old_model_loss = criterion(model(**_preprocess_data(data))) optimizer.backward(old_model_loss) @@ -141,10 +120,9 @@ def _preprocess_data(data): if stage_manager is None or stage_manager.is_first_stage(): assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) - assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data, - new_model.unwrap().h[0].mlp.c_fc.weight.data, - atol=5e-3, - rtol=5e-3) + assert_close_loose( + model.unwrap().h[0].mlp.c_fc.weight.data, new_model.unwrap().h[0].mlp.c_fc.weight.data, atol=5e-3, rtol=5e-3 + ) dist.barrier() Randomizer.reset_index() @@ -153,12 +131,12 @@ def _preprocess_data(data): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [4]) +@pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() def test_hybrid_ckpIO(world_size): spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 7ee733b26b3f..8a4724c8a82c 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -20,9 +20,9 @@ # stage 1 and 2 process the optimizer/mode the same way # only test 2 is fine @clear_cache_before_run() -@parameterize('stage', [2]) -@parameterize('shard', [True, False]) -@parameterize('offload', [False, True]) +@parameterize("stage", [2]) +@parameterize("shard", [True, False]) +@parameterize("offload", [False, True]) def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) booster = Booster(plugin=plugin) @@ -31,7 +31,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): optimizer = HybridAdam((model.parameters()), lr=0.001) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - x = torch.randn(1, 3, 224, 224, device='cuda') + x = torch.randn(1, 3, 224, 224, device="cuda") output = model(x) loss = criterion(output) booster.backward(loss, optimizer) @@ -60,15 +60,16 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): padding = new_optimizer._param_store.get_param_padding_size(working_param) padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding)) working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()] - assert torch.equal(working_shard, - master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)) + assert torch.equal( + working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device) + ) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) def run_dist(rank, world_size, port): - colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") check_low_level_zero_checkpointIO() torch.cuda.empty_cache() diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py index bd041a5e2fd3..c3c30e666b10 100644 --- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py +++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py @@ -1,5 +1,3 @@ -import os - import pytest import torch import torch.distributed as dist @@ -20,18 +18,19 @@ @clear_cache_before_run() -@parameterize('model_name', ['transformers_gpt']) -@parameterize('plugin_type', ['ddp', 'zero', 'gemini']) +@parameterize("model_name", ["transformers_gpt"]) +@parameterize("plugin_type", ["ddp", "zero", "gemini"]) def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32): - (model_fn, data_gen_fn, output_transform_fn, loss_fn, - _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) criterion = loss_fn - if plugin_type == 'ddp': + if plugin_type == "ddp": plugin = TorchDDPPlugin() - elif plugin_type == 'zero': + elif plugin_type == "zero": plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32) - elif plugin_type == 'gemini': + elif plugin_type == "gemini": plugin = GeminiPlugin(precision="fp16", initial_scale=32) else: raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.") @@ -44,7 +43,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) data = data_gen_fn() - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()} output = model(**data) loss = criterion(output) @@ -52,7 +51,6 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per optimizer.step() with shared_tempdir() as tempdir: - model_ckpt_path = f"{tempdir}/model" booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() @@ -62,9 +60,10 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) - if plugin_type == 'gemini': - check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False), - new_model.unwrap().state_dict(only_rank_0=False), False) + if plugin_type == "gemini": + check_state_dict_equal( + model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False + ) else: check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) dist.barrier() @@ -72,12 +71,12 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_from_pretrained() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_huggingface_compatibility(world_size): spawn(run_dist, world_size) diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py index 14332b5b3fca..eeb04df0f42d 100644 --- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -12,8 +12,8 @@ from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn -@parameterize('shard', [True, False]) -@parameterize('size_per_shard', [16, 128]) +@parameterize("shard", [True, False]) +@parameterize("size_per_shard", [16, 128]) def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) @@ -27,7 +27,7 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): assert isinstance(optimizer, OptimizerWrapper) x = torch.randn(4, 3, 224, 224) - x = x.to('cuda') + x = x.to("cuda") output = model(x) loss = criterion(output) booster.backward(loss, optimizer) @@ -47,9 +47,9 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): new_model = resnet18() new_optimizer = SGD((new_model.parameters()), lr=0.001) new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1) - new_model, new_optimizer, _, _, new_scheduler = booster.boost(new_model, - new_optimizer, - lr_scheduler=new_scheduler) + new_model, new_optimizer, _, _, new_scheduler = booster.boost( + new_model, new_optimizer, lr_scheduler=new_scheduler + ) booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) @@ -61,7 +61,7 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int): def run_dist(rank, world_size, port): - colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host="localhost") check_torch_ddp_checkpointIO() diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py index 2b6090bb1e29..dd41f8185c2b 100644 --- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py @@ -1,7 +1,6 @@ import pytest import torch from packaging import version -from torch import nn from torch.optim import SGD from torchvision.models import resnet18 from utils import shared_tempdir @@ -9,11 +8,10 @@ import colossalai from colossalai.booster import Booster -if version.parse(torch.__version__) >= version.parse('1.12.0'): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +if version.parse(torch.__version__) >= version.parse("1.12.0"): from colossalai.booster.plugin import TorchFSDPPlugin -from colossalai.testing import rerun_if_address_is_in_use, spawn, check_state_dict_equal +from colossalai.testing import rerun_if_address_is_in_use, spawn def compare_nested_dict(dict1, dict2): @@ -72,15 +70,16 @@ def run_model(): booster.save_optimizer(optimizer, optim_ckpt_path, shard=False) full_msd = fsdp_model.state_dict() - #full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer) + # full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer) sharded_osd = optimizer.state_dict() import copy + sharded_osd = copy.deepcopy(sharded_osd) run_model() full_msd_updated = fsdp_model.state_dict() - #full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True) + # full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True) sharded_osd_updated = optimizer.state_dict() assert not compare_nested_dict(sharded_osd, sharded_osd_updated) @@ -92,9 +91,9 @@ def run_model(): booster.load_optimizer(optimizer, optim_ckpt_path) full_msd_restore = fsdp_model.state_dict() - #full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True) + # full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True) sharded_osd_restore = optimizer.state_dict() - + assert compare_nested_dict(sharded_osd, sharded_osd_restore) assert compare_nested_dict(full_msd_restore, full_msd) outputs_sec = fsdp_model(inputs) @@ -103,11 +102,11 @@ def run_model(): def run_dist(rank, world_size, port): # init dist env - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") check_torch_fsdp_ckpt() -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher") +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="requires torch1.12 or higher") @rerun_if_address_is_in_use() def test_torch_fsdp_ckpt(): spawn(run_dist, 2) diff --git a/tests/test_checkpoint_io/utils.py b/tests/test_checkpoint_io/utils.py index 2d35e157f446..d14fc944267c 100644 --- a/tests/test_checkpoint_io/utils.py +++ b/tests/test_checkpoint_io/utils.py @@ -15,7 +15,7 @@ def shared_tempdir() -> Iterator[str]: try: obj = [tempdir] dist.broadcast_object_list(obj, src=0) - tempdir = obj[0] # use the same directory on all ranks + tempdir = obj[0] # use the same directory on all ranks yield tempdir finally: dist.barrier() diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py index bb818a275879..ab61cdae5bb0 100644 --- a/tests/test_cluster/test_device_mesh_manager.py +++ b/tests/test_cluster/test_device_mesh_manager.py @@ -1,5 +1,3 @@ -import torch - from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -8,7 +6,7 @@ def check_device_mesh_manager(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") device_mesh_manager = DeviceMeshManager() # TODO(ver217): this test is strictly relies on hardware, temporary skip it # device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],) @@ -20,7 +18,7 @@ def check_device_mesh_manager(rank, world_size, port): physical_ids=[0, 1, 2, 3], mesh_shape=(2, 2), ) - device_mesh_with_shape = device_mesh_manager.create_device_mesh('1', device_mesh_info_with_shape) + device_mesh_with_shape = device_mesh_manager.create_device_mesh("1", device_mesh_info_with_shape) assert device_mesh_with_shape.shape == (2, 2) assert device_mesh_with_shape._logical_mesh_id.tolist() == [[0, 1], [2, 3]] @@ -30,5 +28,5 @@ def test_device_mesh_manager(): spawn(check_device_mesh_manager, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_device_mesh_manager() diff --git a/tests/test_cluster/test_process_group_mesh.py b/tests/test_cluster/test_process_group_mesh.py index 2304203d1e04..08542d1f64fa 100644 --- a/tests/test_cluster/test_process_group_mesh.py +++ b/tests/test_cluster/test_process_group_mesh.py @@ -15,13 +15,15 @@ def check_process_group_mesh_with_gpc(): # check world size assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size( - TP_DIM), f'{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}' + TP_DIM + ), f"{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}" assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM) assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM) # check locak rank (coordinate) assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate( - TP_DIM), f'{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}' + TP_DIM + ), f"{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}" assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM) assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM) @@ -37,21 +39,21 @@ def check_process_group_mesh_with_gpc(): coord = pg_mesh.coordinate() if not gpc.is_first_rank(ParallelMode.TENSOR): assert coord[TP_DIM] != 0 - prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1:] + prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1 :] assert gpc.get_prev_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(prev_coord, pg_mesh.shape) if not gpc.is_first_rank(ParallelMode.PIPELINE): assert coord[PP_DIM] != 0 - prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1:] + prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1 :] assert gpc.get_prev_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(prev_coord, pg_mesh.shape) # check next rank if not gpc.is_last_rank(ParallelMode.TENSOR): assert coord[TP_DIM] != pg_mesh.size(TP_DIM) - 1 - next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1:] + next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1 :] assert gpc.get_next_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(next_coord, pg_mesh.shape) if not gpc.is_last_rank(ParallelMode.PIPELINE): assert coord[PP_DIM] != pg_mesh.size(PP_DIM) - 1 - next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1:] + next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1 :] assert gpc.get_next_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(next_coord, pg_mesh.shape) @@ -108,35 +110,49 @@ def check_process_group_mesh_with_cases(): # check prev rank if RANK_TO_COORDINATE[rank][TP_DIM] != 0: - prev_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] - 1,) + \ - RANK_TO_COORDINATE[rank][TP_DIM + 1:] + prev_coord = ( + RANK_TO_COORDINATE[rank][:TP_DIM] + + (RANK_TO_COORDINATE[rank][TP_DIM] - 1,) + + RANK_TO_COORDINATE[rank][TP_DIM + 1 :] + ) prev_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) - 1] assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank if RANK_TO_COORDINATE[rank][PP_DIM] != 0: - prev_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] - 1,) + \ - RANK_TO_COORDINATE[rank][PP_DIM + 1:] + prev_coord = ( + RANK_TO_COORDINATE[rank][:PP_DIM] + + (RANK_TO_COORDINATE[rank][PP_DIM] - 1,) + + RANK_TO_COORDINATE[rank][PP_DIM + 1 :] + ) prev_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) - 1] assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank # check next rank if RANK_TO_COORDINATE[rank][TP_DIM] != TP_SIZE - 1: - next_coord = RANK_TO_COORDINATE[rank][:TP_DIM] + (RANK_TO_COORDINATE[rank][TP_DIM] + 1,) + \ - RANK_TO_COORDINATE[rank][TP_DIM + 1:] + next_coord = ( + RANK_TO_COORDINATE[rank][:TP_DIM] + + (RANK_TO_COORDINATE[rank][TP_DIM] + 1,) + + RANK_TO_COORDINATE[rank][TP_DIM + 1 :] + ) next_rank = TP_RANKS_IN_GROUP[rank][TP_RANKS_IN_GROUP[rank].index(rank) + 1] assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank if RANK_TO_COORDINATE[rank][PP_DIM] != PP_SIZE - 1: - next_coord = RANK_TO_COORDINATE[rank][:PP_DIM] + (RANK_TO_COORDINATE[rank][PP_DIM] + 1,) + \ - RANK_TO_COORDINATE[rank][PP_DIM + 1:] + next_coord = ( + RANK_TO_COORDINATE[rank][:PP_DIM] + + (RANK_TO_COORDINATE[rank][PP_DIM] + 1,) + + RANK_TO_COORDINATE[rank][PP_DIM + 1 :] + ) next_rank = PP_RANKS_IN_GROUP[rank][PP_RANKS_IN_GROUP[rank].index(rank) + 1] assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank def run_dist(rank, world_size, port): - colossalai.launch(config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode='1d', size=2))), - rank=rank, - world_size=world_size, - port=port, - host='localhost') + colossalai.launch( + config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode="1d", size=2))), + rank=rank, + world_size=world_size, + port=port, + host="localhost", + ) # TODO(ver217): this function should be removed when gpc is removed # check_process_group_mesh_with_gpc() check_process_group_mesh_with_cases() @@ -147,5 +163,5 @@ def test_process_group_mesh(): spawn(run_dist, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_process_group_mesh() diff --git a/tests/test_config/sample_config.py b/tests/test_config/sample_config.py index 08ca108281b9..b9af7ab41a55 100644 --- a/tests/test_config/sample_config.py +++ b/tests/test_config/sample_config.py @@ -3,23 +3,23 @@ train_data = dict( dataset=dict( - type='CIFAR10Dataset', - root='/path/to/data', + type="CIFAR10Dataset", + root="/path/to/data", download=True, transform_pipeline=[ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ] + dict(type="RandomResizedCrop", size=224), + dict(type="RandomHorizontalFlip"), + dict(type="ToTensor"), + dict(type="Normalize", mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + ], ), dataloader=dict( batch_size=64, pin_memory=True, num_workers=4, sampler=dict( - type='DataParallelSampler', + type="DataParallelSampler", shuffle=True, - ) - ) + ), + ), ) diff --git a/tests/test_config/test_load_config.py b/tests/test_config/test_load_config.py index 38b5e3f5f4fc..66e473459445 100644 --- a/tests/test_config/test_load_config.py +++ b/tests/test_config/test_load_config.py @@ -3,16 +3,15 @@ from pathlib import Path -import pytest - from colossalai.context.config import Config def test_load_config(): - filename = Path(__file__).parent.joinpath('sample_config.py') + filename = Path(__file__).parent.joinpath("sample_config.py") config = Config.from_file(filename) - assert config.train_data, 'cannot access train data as attribute' - assert config.train_data.dataset, 'cannot access grandchild attribute' - assert isinstance(config.train_data.dataset.transform_pipeline[0], dict), \ - f'expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}' + assert config.train_data, "cannot access train data as attribute" + assert config.train_data.dataset, "cannot access grandchild attribute" + assert isinstance( + config.train_data.dataset.transform_pipeline[0], dict + ), f"expected attribute transform_pipeline elements to be a dict, but found {type(config.train_data.dataset.transform_pipeline)}" diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py index ab933ed57d0d..f4a88f79c37b 100644 --- a/tests/test_device/test_alpha_beta.py +++ b/tests/test_device/test_alpha_beta.py @@ -8,7 +8,7 @@ def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) ab_dict = profiler.profile_ab() for _, (alpha, beta) in ab_dict.items(): @@ -17,11 +17,11 @@ def check_alpha_beta(rank, world_size, port, physical_devices): @pytest.mark.skip(reason="Skip because assertion fails for CI devices") @pytest.mark.dist -@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): spawn(check_alpha_beta, 4, physical_devices=physical_devices) -if __name__ == '__main__': +if __name__ == "__main__": test_profile_alpha_beta() diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 590d6966bff6..af44af5d9097 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -27,8 +27,8 @@ def check_1d_device_mesh(): # checks assert device_mesh.shape == [4] - assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict' - assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group' + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, "Expected 1 axis for the process group dict" + assert device_mesh.get_process_group(axis=0) == process_group, "Expected world process group" assert device_mesh.is_initialized assert device_mesh.num_devices == 4 assert device_mesh.is_initialized @@ -43,10 +43,10 @@ def check_2d_device_mesh(): first_col_ranks = [0, 2] second_col_ranks = [1, 3] - first_row_pg = dist.new_group(first_row_ranks, backend='nccl') - second_row_pg = dist.new_group(second_row_ranks, backend='nccl') - first_col_pg = dist.new_group(first_col_ranks, backend='nccl') - second_col_pg = dist.new_group(second_col_ranks, backend='nccl') + first_row_pg = dist.new_group(first_row_ranks, backend="nccl") + second_row_pg = dist.new_group(second_row_ranks, backend="nccl") + first_col_pg = dist.new_group(first_col_ranks, backend="nccl") + second_col_pg = dist.new_group(second_col_ranks, backend="nccl") # check for current_rank = dist.get_rank() @@ -65,9 +65,9 @@ def check_2d_device_mesh(): # checks assert device_mesh.shape == [2, 2] - assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict' - assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group' - assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group' + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, "Expected 2 axes for the process group dict" + assert device_mesh.get_process_group(axis=0) == col_pg, "Expected column process group" + assert device_mesh.get_process_group(axis=1) == row_pg, "Expected row process group" assert device_mesh.num_devices == 4 assert device_mesh.is_initialized assert device_mesh.logical_mesh_id is None @@ -75,7 +75,7 @@ def check_2d_device_mesh(): def check_init_from_process_group(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") @pytest.mark.dist @@ -84,6 +84,6 @@ def test_device_mesh_from_process_group(): spawn(check_init_from_process_group, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_device_mesh() test_device_mesh_from_process_group() diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py index 52604b9c6a49..34f2aacc18b2 100644 --- a/tests/test_device/test_extract_alpha_beta.py +++ b/tests/test_device/test_extract_alpha_beta.py @@ -8,7 +8,7 @@ def check_extract_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh() @@ -20,11 +20,11 @@ def check_extract_alpha_beta(rank, world_size, port, physical_devices): @pytest.mark.skip(reason="Skip because assertion may fail for CI devices") @pytest.mark.dist -@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): spawn(check_extract_alpha_beta, 4, physical_devices=physical_devices) -if __name__ == '__main__': +if __name__ == "__main__": test_profile_alpha_beta() diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index c18bf56752fb..3b398a917182 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -9,7 +9,7 @@ def check_layer(rank, world_size, port): - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) assert rank == dist.get_rank() @@ -33,5 +33,5 @@ def test_logical_pg(): spawn(check_layer, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_logical_pg() diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py index b22a76eabc2f..d9d4e79c1f57 100644 --- a/tests/test_device/test_search_logical_device_mesh.py +++ b/tests/test_device/test_search_logical_device_mesh.py @@ -8,7 +8,7 @@ def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") profiler = AlphaBetaProfiler(physical_devices) best_logical_mesh = profiler.search_best_logical_mesh() @@ -20,11 +20,11 @@ def check_alpha_beta(rank, world_size, port, physical_devices): @pytest.mark.skip(reason="Skip because assertion may fail for CI devices") @pytest.mark.dist -@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) +@parameterize("physical_devices", [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): spawn(check_alpha_beta, 4, physical_devices=physical_devices) -if __name__ == '__main__': +if __name__ == "__main__": test_profile_alpha_beta() diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index 6a12f5bc848e..10fe9815541c 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -11,15 +11,16 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False class MLP(torch.nn.Module): - def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(4, 4) @@ -30,7 +31,6 @@ def forward(self, x): class relu(torch.nn.Module): - def __init__(self) -> None: super().__init__() self.relu = torch.nn.ReLU(inplace=True) @@ -40,7 +40,6 @@ def forward(self, x): class MyModule(torch.nn.Module): - def __init__(self): super().__init__() self.mlp1 = MLP() @@ -65,7 +64,7 @@ def forward(self, x, y): def _run_act_ckpt_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -87,26 +86,31 @@ def _run_act_ckpt_codegen(rank, world_size, port): # check ops are annotated with ckpt # also annotate the selected node for offloading - ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] - offload_starts = ['mlp1_linear1'] + ckpt_nodes = ["mlp1_linear1", "mlp1_linear2", "relu_relu", "relu"] + offload_starts = ["mlp1_linear1"] for node in graph.nodes: if node.name in ckpt_nodes: - assert 'activation_checkpoint' in node.meta + assert "activation_checkpoint" in node.meta # annotate the selected node for offload if node.name in offload_starts: - node.meta['activation_offload'] = True + node.meta["activation_offload"] = True gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and # the offload option is correct - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1, data2) @@ -115,7 +119,7 @@ def _run_act_ckpt_codegen(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @rerun_if_address_is_in_use() def test_act_ckpt_codegen(): spawn(_run_act_ckpt_codegen, 1) @@ -123,7 +127,7 @@ def test_act_ckpt_codegen(): def _run_act_ckpt_python_code_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -144,25 +148,30 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): graph._python_code = python_code_with_activation_checkpoint.__get__(graph) # check ops are annotated with ckpt - ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu'] - offload_starts = ['mlp1_linear1'] + ckpt_nodes = ["mlp1_linear1", "mlp1_linear2", "relu_relu", "relu"] + offload_starts = ["mlp1_linear1"] for node in graph.nodes: if node.name in ckpt_nodes: - assert 'activation_checkpoint' in node.meta + assert "activation_checkpoint" in node.meta # annotate the selected node for offload if node.name in offload_starts: - node.meta['activation_offload'] = True + node.meta["activation_offload"] = True gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and # the offload option is correct - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_2, False, y, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_3, False, y, relu, use_reentrant=True)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1, data2) @@ -171,12 +180,12 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") @rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): spawn(_run_act_ckpt_python_code_torch11, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_act_ckpt_codegen(rank=0) diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py index ebcfb4d7b633..f1e87e5ed140 100644 --- a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py @@ -9,15 +9,14 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version - from colossalai.fx.codegen import python_code_with_activation_checkpoint with_codegen = False class MyModule(torch.nn.Module): - def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(4, 4) @@ -33,7 +32,7 @@ def forward(self, x): def _run_act_ckpt_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -54,27 +53,34 @@ def _run_act_ckpt_codegen(rank, world_size, port): # annotate nested checkpoint for node in graph.nodes: if node.name == "linear1": - node.meta['activation_checkpoint'] = [0, 0, 0] + node.meta["activation_checkpoint"] = [0, 0, 0] continue if node.name == "linear2": - node.meta['activation_checkpoint'] = [0, 0, None] + node.meta["activation_checkpoint"] = [0, 0, None] if node.name == "linear3": - node.meta['activation_checkpoint'] = [0, 0, 1] + node.meta["activation_checkpoint"] = [0, 0, 1] if node.name == "linear4": - node.meta['activation_checkpoint'] = [0, 1, None] + node.meta["activation_checkpoint"] = [0, 1, None] if node.name == "linear5": - node.meta['activation_checkpoint'] = 1 + node.meta["activation_checkpoint"] = 1 gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1) @@ -83,14 +89,14 @@ def _run_act_ckpt_codegen(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") def test_act_ckpt_codegen(): spawn(_run_act_ckpt_codegen, 1) def _run_act_ckpt_python_code_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and run forward model = MyModule() @@ -111,27 +117,34 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): # annotate nested checkpoint for node in graph.nodes: if node.name == "linear1": - node.meta['activation_checkpoint'] = [0, 0, 0] + node.meta["activation_checkpoint"] = [0, 0, 0] continue if node.name == "linear2": - node.meta['activation_checkpoint'] = [0, 0, None] + node.meta["activation_checkpoint"] = [0, 0, None] if node.name == "linear3": - node.meta['activation_checkpoint'] = [0, 0, 1] + node.meta["activation_checkpoint"] = [0, 0, 1] if node.name == "linear4": - node.meta['activation_checkpoint'] = [0, 1, None] + node.meta["activation_checkpoint"] = [0, 1, None] if node.name == "linear5": - node.meta['activation_checkpoint'] = 1 + node.meta["activation_checkpoint"] = 1 gm = ColoGraphModule(model, graph) gm.recompile() # assert checkpoint function will be generated and - code = graph.python_code('self').src - assert 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)' in code and \ - 'colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)' in code + code = graph.python_code("self").src + assert ( + "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0, False, x, use_reentrant=False)" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear3, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_0_1, False, linear2, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, x, use_reentrant=False)" + in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_1, False, linear4, use_reentrant=False)" + in code + ) # recompile and verify the outputs are consistent fx_out = gm(data1) @@ -140,12 +153,12 @@ def _run_act_ckpt_python_code_torch11(rank, world_size, port): gpc.destroy() -@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") @rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): spawn(_run_act_ckpt_python_code_torch11, 1) -if __name__ == '__main__': +if __name__ == "__main__": _run_act_ckpt_codegen(rank=0) diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py index dac59c23655e..da1e73ec3dfe 100644 --- a/tests/test_fx/test_codegen/test_offload_codegen.py +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -12,15 +12,16 @@ try: from colossalai.fx.codegen import ActivationCheckpointCodeGen + with_codegen = True except: # fall back to older pytorch version from colossalai.fx.codegen import python_code_with_activation_checkpoint + with_codegen = False class MyNet(torch.nn.Module): - def __init__(self) -> None: super().__init__() self.linear0 = torch.nn.Linear(4, 4) @@ -50,7 +51,6 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool: def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor): - # test forward non_fx_out = model(data) fx_out = gm(data) @@ -66,7 +66,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T def _run_offload_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and input model = MyNet().cuda() @@ -83,37 +83,40 @@ def _run_offload_codegen(rank, world_size, port): # of input offload for node in graph.nodes: if node.name == "linear0": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear1": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear2": - node.meta['activation_offload'] = [1, True, True] + node.meta["activation_offload"] = [1, True, True] if node.name == "linear4": - node.meta['activation_offload'] = [2, False, True] + node.meta["activation_offload"] = [2, False, True] if node.name == "linear5": - node.meta['activation_checkpoint'] = [0] - node.meta['activation_offload'] = True + node.meta["activation_checkpoint"] = [0] + node.meta["activation_offload"] = True gm = ColoGraphModule(copy.deepcopy(model), graph) gm.recompile() # assert we have all the components code = graph.python_code("self").src - assert "def pack_hook_input(self, x):" in code and \ - "def unpack_hook(self, packed):" in code and \ - "def pack_hook_no_input(self, x):" in code and \ - "setattr(x, 'offload', True)" in code and \ - "setattr(linear3, 'offload', False)" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ - "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ - "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + assert ( + "def pack_hook_input(self, x):" in code + and "def unpack_hook(self, packed):" in code + and "def pack_hook_no_input(self, x):" in code + and "setattr(x, 'offload', True)" in code + and "setattr(linear3, 'offload', False)" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code + and "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" + in code + ) _test_fwd_and_bwd(model, gm, data) gpc.destroy() -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") @rerun_if_address_is_in_use() def test_act_ckpt_codegen(): spawn(_run_offload_codegen, 1) @@ -121,7 +124,7 @@ def test_act_ckpt_codegen(): def _run_offload_codegen_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currently - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model and input model = MyNet().cuda() @@ -139,31 +142,34 @@ def _run_offload_codegen_torch11(rank, world_size, port): # of input offload for node in graph.nodes: if node.name == "linear0": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear1": - node.meta['activation_offload'] = [0, True, False] + node.meta["activation_offload"] = [0, True, False] if node.name == "linear2": - node.meta['activation_offload'] = [1, True, True] + node.meta["activation_offload"] = [1, True, True] if node.name == "linear4": - node.meta['activation_offload'] = [2, False, True] + node.meta["activation_offload"] = [2, False, True] if node.name == "linear5": - node.meta['activation_checkpoint'] = [0] - node.meta['activation_offload'] = True + node.meta["activation_checkpoint"] = [0] + node.meta["activation_offload"] = True gm = ColoGraphModule(copy.deepcopy(model), graph) gm.recompile() # assert we have all the components code = graph.python_code("self").src - assert "def pack_hook_input(self, x):" in code and \ - "def unpack_hook(self, packed):" in code and \ - "def pack_hook_no_input(self, x):" in code and \ - "setattr(x, 'offload', True)" in code and \ - "setattr(linear3, 'offload', False)" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \ - "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \ - "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \ - "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code + assert ( + "def pack_hook_input(self, x):" in code + and "def unpack_hook(self, packed):" in code + and "def pack_hook_no_input(self, x):" in code + and "setattr(x, 'offload', True)" in code + and "setattr(linear3, 'offload', False)" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code + and "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code + and "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code + and "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" + in code + ) _test_fwd_and_bwd(model, gm, data) gpc.destroy() diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py index 96cf5198da10..efef368bdd45 100644 --- a/tests/test_fx/test_coloproxy.py +++ b/tests/test_fx/test_coloproxy.py @@ -1,4 +1,3 @@ -import pytest import torch import torch.nn as nn from torch.fx import GraphModule @@ -9,7 +8,6 @@ class Conv1D(nn.Module): - def __init__(self, nf, nx): super().__init__() self.nf = nf @@ -27,10 +25,9 @@ def forward(self, x): @clear_cache_before_run() def test_coloproxy(): - tracer = ColoTracer() model = Conv1D(3, 3) - input_sample = {'x': torch.rand(3, 3).to('meta')} + input_sample = {"x": torch.rand(3, 3).to("meta")} graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) @@ -38,7 +35,7 @@ def test_coloproxy(): node = list(gm.graph.nodes)[0] proxy = ColoProxy(node=node, tracer=tracer) - proxy.meta_data = torch.empty(4, 2, device='meta') + proxy.meta_data = torch.empty(4, 2, device="meta") assert len(proxy) == 4 assert proxy.shape[0] == 4 and proxy.shape[1] == 2 @@ -47,5 +44,5 @@ def test_coloproxy(): assert proxy.size(0) == 4 -if __name__ == '__main__': +if __name__ == "__main__": test_coloproxy() diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py index d3daadd71406..00721ca86ade 100644 --- a/tests/test_fx/test_comm_size_compute.py +++ b/tests/test_fx/test_comm_size_compute.py @@ -17,7 +17,6 @@ class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -36,7 +35,7 @@ def forward(self, x): @clear_cache_before_run() def test_comm_size_compute(): model = MLP(MODEL_DIM) - input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta') + input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device="meta") gm = symbolic_trace(model) if is_compatible: input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device) @@ -49,5 +48,5 @@ def test_comm_size_compute(): assert comm_size == 128 -if __name__ == '__main__': +if __name__ == "__main__": test_comm_size_compute() diff --git a/tests/test_fx/test_graph_manipulation.py b/tests/test_fx/test_graph_manipulation.py index 175b69dd96fe..eece451a706f 100644 --- a/tests/test_fx/test_graph_manipulation.py +++ b/tests/test_fx/test_graph_manipulation.py @@ -1,15 +1,11 @@ import torch -from torch.fx import GraphModule -import colossalai from colossalai.fx import ColoTracer -from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata from colossalai.fx.passes.utils import assign_bfs_level_to_nodes, get_leaf, get_top from colossalai.testing import clear_cache_before_run class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -43,11 +39,11 @@ def test_graph_manipulation(): assert leaf_nodes == set([l4, l5]) assert top_nodes == set([l1, l2]) for node in graph.nodes: - if node.op in ('placeholder', 'output'): - assert not hasattr(node, 'bfs_level') + if node.op in ("placeholder", "output"): + assert not hasattr(node, "bfs_level") else: assert node.bfs_level == compare_dict[node] -if __name__ == '__main__': +if __name__ == "__main__": test_graph_manipulation() diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py index e490522dbf15..7fc7eb4df64b 100644 --- a/tests/test_fx/test_meta/test_aten.py +++ b/tests/test_fx/test_meta/test_aten.py @@ -13,35 +13,41 @@ aten = torch.ops.aten registered_meta = { - ('aten.convolution.default', True): [ # (aten ops, requires_backward) + ("aten.convolution.default", True): [ # (aten ops, requires_backward) (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)), (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)), (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)), - (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4)), - (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, - dilation=2), torch.rand(2, 3, 4, 4, 4)), + ( + nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4), + ), + ( + nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), + torch.rand(2, 3, 4, 4, 4), + ), ], - ('aten.native_batch_norm.default', True): [ + ("aten.native_batch_norm.default", True): [ (nn.BatchNorm1d(4), torch.rand(2, 4)), (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)), (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)), ], - ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),], - ('aten.avg_pool1d.default', True): [ + ("aten.native_layer_norm.default", True): [ + (nn.LayerNorm(4), torch.rand(1, 2, 3, 4)), + ], + ("aten.avg_pool1d.default", True): [ (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)), (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)), (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)), ], - ('aten.avg_pool2d.default', True): [ + ("aten.avg_pool2d.default", True): [ (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)), (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)), ], - ('aten.relu.default', True): [ + ("aten.relu.default", True): [ (nn.ReLU(), torch.rand(4, 3, 1, 2)), (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)), (nn.SiLU(), torch.rand(4, 3, 1, 2)), @@ -50,15 +56,20 @@ (nn.Sigmoid(), torch.rand(4, 3, 1, 2)), (nn.Tanh(), torch.rand(4, 3, 1, 2)), (nn.Hardswish(), torch.rand(4, 3, 1, 2)), - ] + ], } def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any: - assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.' - assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.' - assert tensor.stride() == meta_tensor.stride( - ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.' + assert ( + tensor.shape == meta_tensor.shape + ), f"the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match." + assert ( + tensor.dtype == meta_tensor.dtype + ), f"the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match." + assert ( + tensor.stride() == meta_tensor.stride() + ), f"the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match." def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any: @@ -72,7 +83,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac compare_all(x.grad, meta_x.grad) -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): @@ -80,5 +91,5 @@ def test_meta_aten(): run_and_compare(f, x, requires_backward) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_aten() diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py index 7aed6fd4597b..6091c4b6be2f 100644 --- a/tests/test_fx/test_meta/test_backward.py +++ b/tests/test_fx/test_meta/test_backward.py @@ -23,31 +23,40 @@ ] tmm_models = [ - tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, - tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, - tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, - tmm.swin_transformer.swin_base_patch4_window7_224 + tmm.resnest.resnest50d, + tmm.beit.beit_base_patch16_224, + tmm.cait.cait_s24_224, + tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, + tmm.vision_transformer.vit_base_patch16_224, + tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, + tmm.vgg.vgg11, + tmm.dpn.dpn68, + tmm.densenet.densenet121, + tmm.rexnet.rexnet_100, + tmm.swin_transformer.swin_base_patch4_window7_224, ] -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_torchvision_models(): for m in tm_models: model = m() - data = torch.rand(100000, 3, 224, 224, device='meta') - model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() + data = torch.rand(100000, 3, 224, 224, device="meta") + model(MetaTensor(data, fake_device=torch.device("cpu"))).sum().backward() -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_timm_models(): for m in tmm_models: model = m() - data = torch.rand(100000, 3, 224, 224, device='meta') - model(MetaTensor(data, fake_device=torch.device('cpu'))).sum().backward() + data = torch.rand(100000, 3, 224, 224, device="meta") + model(MetaTensor(data, fake_device=torch.device("cpu"))).sum().backward() -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models() test_timm_models() diff --git a/tests/test_fx/test_meta/test_meta_trace.py b/tests/test_fx/test_meta/test_meta_trace.py index 61614f8a6623..ba9617a38380 100644 --- a/tests/test_fx/test_meta/test_meta_trace.py +++ b/tests/test_fx/test_meta/test_meta_trace.py @@ -23,31 +23,40 @@ ] tmm_models = [ - tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m, - tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224, - tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100, - tmm.swin_transformer.swin_base_patch4_window7_224 + tmm.resnest.resnest50d, + tmm.beit.beit_base_patch16_224, + tmm.cait.cait_s24_224, + tmm.efficientnet.efficientnetv2_m, + tmm.resmlp_12_224, + tmm.vision_transformer.vit_base_patch16_224, + tmm.deit_base_distilled_patch16_224, + tmm.convnext.convnext_base, + tmm.vgg.vgg11, + tmm.dpn.dpn68, + tmm.densenet.densenet121, + tmm.rexnet.rexnet_100, + tmm.swin_transformer.swin_base_patch4_window7_224, ] -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_torchvision_models_trace(): for m in tm_models: model = m() - data = torch.rand(1000, 3, 224, 224, device='meta') - graph = meta_trace(model, torch.device('cpu'), data) + data = torch.rand(1000, 3, 224, 224, device="meta") + meta_trace(model, torch.device("cpu"), data) -@pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@pytest.mark.skipif(not is_compatible_with_meta(), reason="torch version is lower than 1.12.0") @clear_cache_before_run() def test_timm_models_trace(): for m in tmm_models: model = m() - data = torch.rand(1000, 3, 224, 224, device='meta') - graph = meta_trace(model, torch.device('cpu'), data) + data = torch.rand(1000, 3, 224, 224, device="meta") + meta_trace(model, torch.device("cpu"), data) -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models_trace() test_timm_models_trace() diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index a12512696a73..659949e87002 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -23,18 +23,18 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): @clear_cache_before_run() def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) - input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') + input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="meta") if is_compatible_with_meta(): - input_sample = MetaTensor(input_sample, fake_device='cpu') + input_sample = MetaTensor(input_sample, fake_device="cpu") orig_output = model(input_sample) gm = symbolic_trace(model) MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: - if node.op == 'placeholder': - meta_check(node.meta['tensor_meta'], input_sample) - if node.op == 'output': - meta_check(node.meta['tensor_meta'], orig_output) + if node.op == "placeholder": + meta_check(node.meta["tensor_meta"], input_sample) + if node.op == "output": + meta_check(node.meta["tensor_meta"], orig_output) -if __name__ == '__main__': +if __name__ == "__main__": test_meta_info_prop() diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py index 29135b45f997..6d890f59d5c5 100644 --- a/tests/test_fx/test_parallel_1d.py +++ b/tests/test_fx/test_parallel_1d.py @@ -13,7 +13,6 @@ class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -29,12 +28,12 @@ def forward(self, x): return x -CONFIG = dict(parallel=dict(tensor=dict(mode='1d', size=2))) +CONFIG = dict(parallel=dict(tensor=dict(mode="1d", size=2))) def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") input_tensor = torch.rand(2, 16).cuda() model = MLP(16).cuda() symbolic_traced = symbolic_trace(model) @@ -55,5 +54,5 @@ def test_1d(): spawn(check_layer, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_1d() diff --git a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py index 3afc6c97e2bb..b86c71db85c2 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py +++ b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py @@ -1,11 +1,12 @@ -import torch -from torch.fx import symbolic_trace -from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer import inspect import random + import numpy as np +import torch +from torch.fx import GraphModule + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass MANUAL_SEED = 0 random.seed(MANUAL_SEED) @@ -26,7 +27,7 @@ def split_model_and_compare_output(model, data_gen): # tracing model tracer = ColoTracer() try: - meta_args = {k: v.to('meta') for k, v in kwargs.items()} + meta_args = {k: v.to("meta") for k, v in kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") @@ -49,16 +50,16 @@ def split_model_and_compare_output(model, data_gen): output_part1 = model_part1(output_part0) else: if len(output_part0) > len(sig.parameters): - output_part0 = output_part0[:len(sig.parameters)] + output_part0 = output_part0[: len(sig.parameters)] output_part1 = model_part1(*output_part0) # get output tensor from HFOutput datastructure - if 'logits' in output: - output_to_compare = output['logits'] - elif 'prediction_logits' in output: - output_to_compare = output['prediction_logits'] + if "logits" in output: + output_to_compare = output["logits"] + elif "prediction_logits" in output: + output_to_compare = output["prediction_logits"] else: - output_to_compare = output['last_hidden_state'] + output_to_compare = output["last_hidden_state"] # compare output if isinstance(output_part1, torch.Tensor): diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py index 6ef861bdefbe..d15081b0b3ad 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py @@ -7,7 +7,7 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_single_sentence_albert(): MODEL_LIST = [ transformers.AlbertModel, @@ -17,12 +17,14 @@ def test_single_sentence_albert(): transformers.AlbertForTokenClassification, ] - config = transformers.AlbertConfig(vocab_size=100, - embedding_size=128, - hidden_size=128, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=256) + config = transformers.AlbertConfig( + vocab_size=100, + embedding_size=128, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256, + ) def data_gen(): input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) @@ -36,5 +38,5 @@ def data_gen(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_single_sentence_albert() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py index a7550413fac8..3588033d1ecd 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py @@ -7,7 +7,7 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_single_sentence_bert(): MODEL_LIST = [ transformers.BertModel, @@ -18,11 +18,9 @@ def test_single_sentence_bert(): transformers.BertForTokenClassification, ] - config = transformers.BertConfig(vocab_size=100, - hidden_size=128, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=256) + config = transformers.BertConfig( + vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4, intermediate_size=256 + ) def data_gen(): input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) @@ -36,5 +34,5 @@ def data_gen(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_single_sentence_bert() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py index 6181c5c0706a..d2533aea4003 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py @@ -9,14 +9,14 @@ NUM_CHUNKS = 1 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_gpt(): MODEL_LIST = [ transformers.GPT2Model, transformers.GPT2LMHeadModel, transformers.GPT2DoubleHeadsModel, transformers.GPT2ForTokenClassification, - # transformers.GPT2ForSequenceClassification, # not supported yet + # transformers.GPT2ForSequenceClassification, # not supported yet ] config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=8) @@ -32,5 +32,5 @@ def data_gen(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_gpt() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py index 1a9b36be82bd..e67628d10364 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py @@ -7,7 +7,7 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_opt(): MODEL_LIST = [ transformers.OPTModel, @@ -27,5 +27,5 @@ def data_gen(): split_model_and_compare_output(model, data_gen) -if __name__ == '__main__': +if __name__ == "__main__": test_opt() diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py index 16d0163746b3..dc36fdb13152 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py @@ -7,7 +7,7 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_t5(): MODEL_LIST = [ transformers.T5Model, @@ -39,5 +39,5 @@ def data_gen_for_encoder_only(): split_model_and_compare_output(model, data_gen_func) -if __name__ == '__main__': +if __name__ == "__main__": test_t5() diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py index 6fb1f6f4bb23..c4fe5547ed8d 100644 --- a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py +++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py @@ -4,9 +4,8 @@ from timm_utils import split_model_and_compare_output -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_timm_models_without_control_flow(): - MODEL_LIST = [ tm.resnest.resnest50d, tm.beit.beit_base_patch16_224, @@ -25,24 +24,28 @@ def test_timm_models_without_control_flow(): split_model_and_compare_output(model, data) -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_timm_models_with_control_flow(): torch.backends.cudnn.deterministic = True MODEL_LIST_WITH_CONTROL_FLOW = [ - tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100, - tm.swin_transformer.swin_base_patch4_window7_224 + tm.convnext.convnext_base, + tm.vgg.vgg11, + tm.dpn.dpn68, + tm.densenet.densenet121, + tm.rexnet.rexnet_100, + tm.swin_transformer.swin_base_patch4_window7_224, ] data = torch.rand(2, 3, 224, 224) - meta_args = {'x': data.to('meta')} + meta_args = {"x": data.to("meta")} for model_cls in MODEL_LIST_WITH_CONTROL_FLOW: model = model_cls() split_model_and_compare_output(model, data, meta_args) -if __name__ == '__main__': +if __name__ == "__main__": test_timm_models_without_control_flow() test_timm_models_with_control_flow() diff --git a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py index aa870e5c7a65..e1182c8d4978 100644 --- a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py +++ b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py @@ -1,11 +1,12 @@ -import torch -from torch.fx import symbolic_trace -from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer import inspect import random + import numpy as np +import torch +from torch.fx import GraphModule + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass MANUAL_SEED = 0 random.seed(MANUAL_SEED) @@ -46,6 +47,6 @@ def split_model_and_compare_output(model, data, meta_args=None): output_part1 = model_part1(output_part0) else: if len(output_part0) > len(sig.parameters): - output_part0 = output_part0[:len(sig.parameters)] + output_part0 = output_part0[: len(sig.parameters)] output_part1 = model_part1(*output_part0) assert output.equal(output_part1) diff --git a/tests/test_fx/test_pipeline/test_topo/test_topo.py b/tests/test_fx/test_pipeline/test_topo/test_topo.py index 16da56250dc3..7c420ef2385a 100644 --- a/tests/test_fx/test_pipeline/test_topo/test_topo.py +++ b/tests/test_fx/test_pipeline/test_topo/test_topo.py @@ -7,7 +7,7 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@pytest.mark.skip("ShapeProp is not compatible with PyTorch 1.11.0") def test_opt(): MODEL_LIST = [ MLP, @@ -15,10 +15,7 @@ def test_opt(): ] CONFIGS = [ - { - 'dim': 10, - 'layers': 12 - }, + {"dim": 10, "layers": 12}, transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4), ] @@ -45,5 +42,5 @@ def data_gen_OPT(): check_topo(top_mod, topo) -if __name__ == '__main__': +if __name__ == "__main__": test_opt() diff --git a/tests/test_fx/test_pipeline/test_topo/topo_utils.py b/tests/test_fx/test_pipeline/test_topo/topo_utils.py index db6cadfc544c..6a69181a6d26 100644 --- a/tests/test_fx/test_pipeline/test_topo/topo_utils.py +++ b/tests/test_fx/test_pipeline/test_topo/topo_utils.py @@ -6,7 +6,7 @@ from colossalai.fx import ColoTracer from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass -from colossalai.legacy.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo +from colossalai.legacy.pipeline.middleware import Partition, Topo from colossalai.legacy.pipeline.middleware.adaptor import get_fx_topology MANUAL_SEED = 0 @@ -16,11 +16,10 @@ class MLP(torch.nn.Module): - def __init__(self, config={}): super().__init__() - dim = config['dim'] - layers = config['layers'] + dim = config["dim"] + layers = config["layers"] self.layers = torch.nn.ModuleList() for _ in range(layers): @@ -41,7 +40,7 @@ def split_model_and_get_DAG(model, data_gen): # tracing model tracer = ColoTracer() try: - meta_args = {k: v.to('meta') for k, v in kwargs.items()} + meta_args = {k: v.to("meta") for k, v in kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") @@ -55,7 +54,7 @@ def split_model_and_get_DAG(model, data_gen): topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_topo', topo) + setattr(submodule, "_topo", topo) return top_module, split_submodules[0]._topo @@ -64,7 +63,7 @@ def check_input(top_module, input_partition: Partition): partition_output = input_partition.get_output_vals() arg_pos = 0 for node in top_module.graph.nodes: - if node.op == 'placeholder': + if node.op == "placeholder": cur_checkee = partition_output[arg_pos] to_partition_and_offset = cur_checkee.get() assert len(to_partition_and_offset) == len(node.users.keys()) @@ -80,7 +79,7 @@ def check_submod(top_module, part_id, mid_partition: Partition): cnt = 1 cur_node = None for node in top_module.graph.nodes: - if node.name.startswith('submod'): + if node.name.startswith("submod"): cnt += 1 if cnt == part_id: cur_node = node diff --git a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py index 5d47be2c7bea..063e51309503 100644 --- a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py +++ b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py @@ -19,14 +19,21 @@ torch.backends.cudnn.deterministic = True -@pytest.mark.skip('balance split v2 is not ready') +@pytest.mark.skip("balance split v2 is not ready") def test_torchvision_models(): MODEL_LIST = [ - tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, - tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5 + tm.vgg11, + tm.resnet18, + tm.densenet121, + tm.mobilenet_v3_small, + tm.resnext50_32x4d, + tm.wide_resnet50_2, + tm.regnet_x_16gf, + tm.efficientnet_b0, + tm.mnasnet0_5, ] - if version.parse(torchvision.__version__) >= version.parse('0.12.0'): + if version.parse(torchvision.__version__) >= version.parse("0.12.0"): MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small]) tracer = ColoTracer() @@ -57,10 +64,10 @@ def test_torchvision_models(): output_part1 = model_part1(output_part0) else: if len(output_part0) > len(sig.parameters): - output_part0 = output_part0[:len(sig.parameters)] + output_part0 = output_part0[: len(sig.parameters)] output_part1 = model_part1(*output_part0) assert output.equal(output_part1) -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models() diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py index 1078dac9db7c..7a5a397500bb 100644 --- a/tests/test_fx/test_pipeline_passes.py +++ b/tests/test_fx/test_pipeline_passes.py @@ -1,10 +1,6 @@ -import pytest import torch -import torch.nn as nn from torch.fx import symbolic_trace -import colossalai -import colossalai.nn as col_nn from colossalai.fx.passes.adding_split_node_pass import ( balanced_split_pass, balanced_split_pass_v2, @@ -19,7 +15,6 @@ class MLP(torch.nn.Module): - def __init__(self, dim: int): super().__init__() self.linear1 = torch.nn.Linear(dim, dim) @@ -53,5 +48,5 @@ def test_pipeline_passes(): pipeline_pass_test_helper(model, data, uniform_split_pass) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_passes() diff --git a/tests/test_fx/test_profiler/gpt_utils.py b/tests/test_fx/test_profiler/gpt_utils.py index aec32268484f..9e4214876ba7 100644 --- a/tests/test_fx/test_profiler/gpt_utils.py +++ b/tests/test_fx/test_profiler/gpt_utils.py @@ -1,26 +1,29 @@ -import torch import torch.nn as nn from transformers import GPT2Config, GPT2LMHeadModel class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257, - checkpoint=False): + def __init__( + self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50257, + checkpoint=False, + ): super().__init__() self.checkpoint = checkpoint self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size)) + GPT2Config( + n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + ) + ) if checkpoint: self.model.gradient_checkpointing_enable() @@ -30,7 +33,6 @@ def forward(self, input_ids, attention_mask): class GPTLMLoss(nn.Module): - def __init__(self): super().__init__() self.loss_fn = nn.CrossEntropyLoss() diff --git a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py index b5a6bbe8bf18..28409696ca55 100644 --- a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py +++ b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py @@ -1,9 +1,9 @@ -from typing import Optional, Tuple, Union +from typing import Tuple import torch import torch.fx import torchvision.models as tm -from gpt_utils import gpt2_medium, gpt2_xl +from gpt_utils import gpt2_medium from torch.fx import symbolic_trace from colossalai.fx.passes.meta_info_prop import MetaInfoProp @@ -33,18 +33,18 @@ def extract_forward_flops(gm: torch.fx.GraphModule): fwd_flop = 0 bwd_flop = 0 for node in gm.graph.nodes: - fwd_flop += node.meta.get('fwd_flop', 0) - bwd_flop += node.meta.get('bwd_flop', 0) + fwd_flop += node.meta.get("fwd_flop", 0) + bwd_flop += node.meta.get("bwd_flop", 0) return fwd_flop, bwd_flop -def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device='cuda'): +def gen_tm_data(batch_size: int, shape: Tuple[int, int, int], device="cuda"): data = torch.rand(batch_size, *shape, device=device) label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000) return data, label -def gen_gpt_data(batch_size, seq_len, vocab_size, device='cpu'): +def gen_gpt_data(batch_size, seq_len, vocab_size, device="cpu"): input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) attention_mask = torch.ones_like(input_ids, device=device) return input_ids, attention_mask @@ -96,7 +96,7 @@ def run_gpt_forward(gm: torch.fx.GraphModule): param_mem += torch.cuda.memory_allocated(device="cuda:0") / 1024**2 for n in range(NUM_STEPS): torch.cuda.reset_peak_memory_stats() - data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='cuda:0') + data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device="cuda:0") # If we need to dive deep into the memory usage by # inspecting `saved_tensor_hooks` @@ -125,21 +125,56 @@ def run_gpt_forward(gm: torch.fx.GraphModule): return forward_mem, param_mem -@run_on_environment_flag(name='FX_PROFILER') +@run_on_environment_flag(name="FX_PROFILER") @clear_cache_before_run() def test_meta_info_prop(): for m in [ - tm.alexnet, tm.resnet18, tm.resnet34, tm.resnet50, tm.resnet101, tm.resnet152, tm.densenet121, - tm.densenet161, tm.densenet169, tm.densenet201, tm.convnext_tiny, tm.convnext_small, tm.convnext_base, - tm.convnext_large, tm.wide_resnet50_2, tm.wide_resnet101_2, tm.regnet_x_16gf, tm.mnasnet0_5, - tm.efficientnet_b0, tm.shufflenet_v2_x0_5, tm.shufflenet_v2_x1_0, tm.shufflenet_v2_x1_5, - tm.shufflenet_v2_x2_0, tm.mobilenet_v2, tm.mobilenet_v3_small, tm.mobilenet_v3_large, tm.resnext50_32x4d, - tm.resnext101_32x8d, tm.resnext101_64x4d, tm.vit_b_16, tm.vit_b_32, tm.vit_h_14, tm.vit_l_16, tm.vit_l_32, - tm.vgg11, tm.vgg11_bn, tm.vgg13, tm.vgg13_bn, tm.vgg16, tm.vgg16_bn, tm.vgg19, tm.vgg19_bn + tm.alexnet, + tm.resnet18, + tm.resnet34, + tm.resnet50, + tm.resnet101, + tm.resnet152, + tm.densenet121, + tm.densenet161, + tm.densenet169, + tm.densenet201, + tm.convnext_tiny, + tm.convnext_small, + tm.convnext_base, + tm.convnext_large, + tm.wide_resnet50_2, + tm.wide_resnet101_2, + tm.regnet_x_16gf, + tm.mnasnet0_5, + tm.efficientnet_b0, + tm.shufflenet_v2_x0_5, + tm.shufflenet_v2_x1_0, + tm.shufflenet_v2_x1_5, + tm.shufflenet_v2_x2_0, + tm.mobilenet_v2, + tm.mobilenet_v3_small, + tm.mobilenet_v3_large, + tm.resnext50_32x4d, + tm.resnext101_32x8d, + tm.resnext101_64x4d, + tm.vit_b_16, + tm.vit_b_32, + tm.vit_h_14, + tm.vit_l_16, + tm.vit_l_32, + tm.vgg11, + tm.vgg11_bn, + tm.vgg13, + tm.vgg13_bn, + tm.vgg16, + tm.vgg16_bn, + tm.vgg19, + tm.vgg19_bn, ]: model = m().cuda() model.train() - data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device='meta'), fake_device='cuda:0') + data = MetaTensor(torch.rand(int(TM_BATCH_SIZE), 3, 224, 224, device="meta"), fake_device="cuda:0") gm = symbolic_trace(model) interp = MetaInfoProp(gm) interp.propagate(data) @@ -150,22 +185,22 @@ def test_meta_info_prop(): concrete_forward_mem, concrete_param_mem = run_tm_forward(gm) print( - f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|' + f"|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|" ) del model, gm -@run_on_environment_flag(name='FX_PROFILER') +@run_on_environment_flag(name="FX_PROFILER") @clear_cache_before_run() def test_gpt_meta_info_prop(): for m in [gpt2_medium]: model = m().cuda() model.train() - data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device='meta') - graph = ColoTracer().trace(model, meta_args={'input_ids': data, 'attention_mask': mask}) + data, mask = gen_gpt_data(GPT_BATCH_SIZE, 1024, 50257, device="meta") + graph = ColoTracer().trace(model, meta_args={"input_ids": data, "attention_mask": mask}) gm = torch.fx.GraphModule(model, graph) interp = MetaInfoProp(gm) - interp.propagate(MetaTensor(data, fake_device='cuda:0'), MetaTensor(mask, fake_device='cuda:0')) + interp.propagate(MetaTensor(data, fake_device="cuda:0"), MetaTensor(mask, fake_device="cuda:0")) model.cpu() fwd_flop, bwd_flop = extract_forward_flops(gm) @@ -174,11 +209,11 @@ def test_gpt_meta_info_prop(): meta_forward_mem, meta_param_mem = extract_forward_mem(gm) print( - f'|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|' + f"|{m.__name__}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|fwd_flop={fwd_flop / 1e9:.3f}GFLOPs|bwd_flop={bwd_flop / 1e9:.3f}GFLOPs|" ) del model, gm -if __name__ == '__main__': +if __name__ == "__main__": test_meta_info_prop() test_gpt_meta_info_prop() diff --git a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py index 632ab8c09750..e7dcf07aafb4 100644 --- a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py +++ b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py @@ -1,5 +1,4 @@ import torch -import torch.nn as nn from torch.fx import GraphModule from torch.utils.checkpoint import checkpoint @@ -8,7 +7,6 @@ class MLP(torch.nn.Module): - def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(4, 4) @@ -22,7 +20,6 @@ def forward(self, x): # Simple module for demonstration class MyModule(torch.nn.Module): - def __init__(self): super().__init__() self.mlp_1 = MLP() @@ -46,20 +43,20 @@ def test_activation_checkpoint_annotation(): gm = GraphModule(module, graph) for node in gm.graph.nodes: - if node.name in ['mlp_1_linear1', 'mlp_1_linear2']: - assert node.meta.get('activation_checkpoint', -1) == 0 + if node.name in ["mlp_1_linear1", "mlp_1_linear2"]: + assert node.meta.get("activation_checkpoint", -1) == 0 for node in gm.graph.nodes: - if node.name in ['mlp_2_linear1', 'mlp_2_linear2']: - assert node.meta.get('activation_checkpoint', -1) == 1 + if node.name in ["mlp_2_linear1", "mlp_2_linear2"]: + assert node.meta.get("activation_checkpoint", -1) == 1 tracer = ColoTracer(trace_act_ckpt=False) graph = tracer.trace(module) gm = GraphModule(module, graph) for node in gm.graph.nodes: - assert not hasattr(node, 'activation_checkpoint') + assert not hasattr(node, "activation_checkpoint") -if __name__ == '__main__': +if __name__ == "__main__": test_activation_checkpoint_annotation() diff --git a/tests/test_fx/test_tracer/test_bias_addition_module.py b/tests/test_fx/test_tracer/test_bias_addition_module.py index 2f88d8c784e8..e53894bdfd71 100644 --- a/tests/test_fx/test_tracer/test_bias_addition_module.py +++ b/tests/test_fx/test_tracer/test_bias_addition_module.py @@ -5,7 +5,6 @@ class LinearModel(torch.nn.Module): - def __init__(self, in_features, out_features): super().__init__() self.linear = torch.nn.Linear(in_features, out_features) @@ -18,13 +17,11 @@ def forward(self, x): class ConvModel(torch.nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, bias=True): super().__init__() - self.conv = torch.nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - bias=bias) + self.conv = torch.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias + ) def forward(self, x): x = self.conv(x) @@ -45,7 +42,7 @@ def test_linear_module(): # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) # return mul - graph = tracer.trace(root=model, meta_args={'x': torch.rand(3, 3).to('meta')}) + graph = tracer.trace(root=model, meta_args={"x": torch.rand(3, 3).to("meta")}) # def forward(self, x : torch.Tensor): # linear_weight = self.linear.weight # linear_bias = self.linear.bias @@ -57,9 +54,9 @@ def test_linear_module(): gm.recompile() node_list = list(graph.nodes) for node in node_list: - if node.op == 'output': + if node.op == "output": continue - assert hasattr(node, '_meta_data') + assert hasattr(node, "_meta_data") weight_node = node_list[1] bias_node = node_list[2] linear_node = node_list[3] @@ -83,7 +80,7 @@ def test_conv_module(): # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) # return mul - graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) + graph = tracer.trace(root=model, meta_args={"x": torch.rand(4, 3, 64, 64).to("meta")}) # def forward(self, x : torch.Tensor): # conv_weight = self.conv.weight # conv_bias = self.conv.bias @@ -97,9 +94,9 @@ def test_conv_module(): gm.recompile() node_list = list(graph.nodes) for node in node_list: - if node.op == 'output': + if node.op == "output": continue - assert hasattr(node, '_meta_data') + assert hasattr(node, "_meta_data") weight_node = node_list[1] bias_node = node_list[2] conv_node = node_list[3] @@ -112,6 +109,6 @@ def test_conv_module(): assert add_node._meta_data.shape == (4, 6, 63, 63) -if __name__ == '__main__': +if __name__ == "__main__": test_linear_module() test_conv_module() diff --git a/tests/test_fx/test_tracer/test_control_flow.py b/tests/test_fx/test_tracer/test_control_flow.py index 820729dadb3e..f0c261c39db5 100644 --- a/tests/test_fx/test_tracer/test_control_flow.py +++ b/tests/test_fx/test_tracer/test_control_flow.py @@ -7,7 +7,6 @@ class ControlFlowModel(nn.Module): - def __init__(self): super().__init__() self.linear1 = nn.Linear(10, 10) @@ -27,16 +26,12 @@ def forward(self, x, y): def test_control_flow(): model = ControlFlowModel() tracer = Tracer() - graph_branch_true = tracer.trace(model, - meta_args={ - 'x': torch.rand(4, 10, device='meta'), - 'y': torch.rand(4, 10, device='meta') - }) - graph_branch_false = tracer.trace(model, - meta_args={ - 'x': torch.rand(10, device='meta'), - 'y': torch.rand(4, 10, device='meta') - }) + graph_branch_true = tracer.trace( + model, meta_args={"x": torch.rand(4, 10, device="meta"), "y": torch.rand(4, 10, device="meta")} + ) + graph_branch_false = tracer.trace( + model, meta_args={"x": torch.rand(10, device="meta"), "y": torch.rand(4, 10, device="meta")} + ) gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__) gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__) @@ -56,5 +51,5 @@ def test_control_flow(): assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y)) -if __name__ == '__main__': +if __name__ == "__main__": test_control_flow() diff --git a/tests/test_fx/test_tracer/test_functional_conv.py b/tests/test_fx/test_tracer/test_functional_conv.py index a552e905223d..63f9721e2a65 100644 --- a/tests/test_fx/test_tracer/test_functional_conv.py +++ b/tests/test_fx/test_tracer/test_functional_conv.py @@ -47,5 +47,5 @@ def test_conv(): assert out_transpose_3d.shape == patched_out_transpose_3d.shape -if __name__ == '__main__': +if __name__ == "__main__": test_conv() diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index e6f8df2e0af7..4828bb0302c8 100644 --- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -1,9 +1,6 @@ from typing import List import torch -from numpy import isin -from torch.fx import GraphModule -from torch.utils._pytree import tree_flatten # from colossalai.fx import symbolic_trace from colossalai._analyzer.fx import symbolic_trace @@ -20,7 +17,7 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non inputs = {k: v for k, v in inputs.items() if k not in ignore_data} try: - meta_args = {k: v.to('meta') for k, v in inputs.items()} + meta_args = {k: v.to("meta") for k, v in inputs.items()} gm = symbolic_trace(model, meta_args=meta_args) except Exception as e: @@ -35,4 +32,4 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non if torch.is_tensor(fx_out[k]): assert torch.equal( fx_out[k], non_fx_out[k] - ), f'{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}' + ), f"{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}" diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index a1470400ad82..fb093821e488 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -10,15 +10,15 @@ SEQ_LENGTH = 16 -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_albert(): - sub_registry = model_zoo.get_sub_registry('transformers_albert') + sub_registry = model_zoo.get_sub_registry("transformers_albert") for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() trace_model_and_compare_output(model, data_gen_fn) -if __name__ == '__main__': +if __name__ == "__main__": test_albert() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index 7773de480302..91f7b9764e6e 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -7,17 +7,17 @@ from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_bert(): - sub_registry = model_zoo.get_sub_registry('transformers_bert') + sub_registry = model_zoo.get_sub_registry("transformers_bert") for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() if model.__class__.__name__ == "BertForQuestionAnswering": continue - trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label']) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels", "next_sentence_label"]) -if __name__ == '__main__': +if __name__ == "__main__": test_bert() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index ac87a7fcb13b..95a464fa0534 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -22,7 +22,7 @@ def trace_and_compare(model_cls, data, output_fn): model.eval() concrete_args = {k: v for k, v in data.items() if not torch.is_tensor(v)} - meta_args = {k: v.to('meta') for k, v in data.items() if torch.is_tensor(v)} + meta_args = {k: v.to("meta") for k, v in data.items() if torch.is_tensor(v)} gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args) # run forward @@ -40,12 +40,12 @@ def assert_fn(ta, tb): assert_dict(transformed_fx_out, transformed_non_fx_out, assert_fn) -@pytest.mark.skip(reason='cannot pass this test yet') +@pytest.mark.skip(reason="cannot pass this test yet") @clear_cache_before_run() def test_diffusers(): seed_all(9091, cuda_deterministic=True) - sub_model_zoo = model_zoo.get_sub_registry('diffusers') + sub_model_zoo = model_zoo.get_sub_registry("diffusers") for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() @@ -58,12 +58,12 @@ def test_diffusers(): def test_torch_diffusers(): seed_all(65535, cuda_deterministic=True) - sub_model_zoo = model_zoo.get_sub_registry('diffusers') + sub_model_zoo = model_zoo.get_sub_registry("diffusers") for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() model = model_fn() - output = model(**data) + model(**data) torch.cuda.synchronize() print(f"{name:40s} √") diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 1cd3b90db917..7bd8a726f1ac 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -7,10 +7,10 @@ from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_gpt(): - sub_registry = model_zoo.get_sub_registry('transformers_gpt') + sub_registry = model_zoo.get_sub_registry("transformers_gpt") for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() @@ -18,11 +18,11 @@ def test_gpt(): # TODO(ver217): support the following models # 1. GPT2DoubleHeadsModel # as they are not supported, let's skip them - if model.__class__.__name__ in ['GPT2DoubleHeadsModel', 'GPT2ForQuestionAnswering']: + if model.__class__.__name__ in ["GPT2DoubleHeadsModel", "GPT2ForQuestionAnswering"]: continue - trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels"]) -if __name__ == '__main__': +if __name__ == "__main__": test_gpt() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index c68b89e82fbe..5f7525d5707b 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -7,14 +7,14 @@ from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_opt(): - sub_registry = model_zoo.get_sub_registry('transformers_opt') + sub_registry = model_zoo.get_sub_registry("transformers_opt") for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'start_positions', 'end_positions']) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels", "start_positions", "end_positions"]) -if __name__ == '__main__': +if __name__ == "__main__": test_opt() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 45e06bc2bbb0..6ccbb14e3d96 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -7,20 +7,20 @@ from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_t5(): - sub_registry = model_zoo.get_sub_registry('transformers_t5') + sub_registry = model_zoo.get_sub_registry("transformers_t5") for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): if name == "transformers_t5_for_conditional_generation": # cannot trace for loss function yet # so we use a data gen which does not produce labels - data_gen_fn = sub_registry.get('transformers_t5')[1] + data_gen_fn = sub_registry.get("transformers_t5")[1] model = model_fn() - trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=["labels"]) -if __name__ == '__main__': +if __name__ == "__main__": test_t5() diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py index ef778e21801a..fe66cbd0ffcc 100644 --- a/tests/test_fx/test_tracer/test_patched_module.py +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -36,12 +36,12 @@ def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape) @clear_cache_before_run() def test_linear(): # test linear patch can produce the meta output with correct shape - data = torch.rand(2, 4, device='meta') + data = torch.rand(2, 4, device="meta") module = torch.nn.Linear(4, 2) _assert_output_shape(data, module, patched_module.torch_nn_linear, False, torch.Size([2, 2])) # test if the linear patch can catch exception when dimension does not match - data = torch.rand(2, 2, device='meta') + data = torch.rand(2, 2, device="meta") _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None) @@ -51,20 +51,20 @@ def test_rnn(): data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) module = torch.nn.RNN(10, 20, 2) output, hn = module(*data) - meta_data = (torch.randn(5, 3, 10).to('meta'), torch.randn(2, 3, 20).to('meta')) + meta_data = (torch.randn(5, 3, 10).to("meta"), torch.randn(2, 3, 20).to("meta")) _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, False, (output.shape, hn.shape)) # test if the rnn patch can catch exception when dimension does not match data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) module = torch.nn.RNN(10, 20, 2) output, hn = module(*data) - meta_data = (torch.randn(5, 3, 1).to('meta'), torch.randn(2, 3, 20).to('meta')) + meta_data = (torch.randn(5, 3, 1).to("meta"), torch.randn(2, 3, 20).to("meta")) _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None) @clear_cache_before_run() def test_embedding(): - data = torch.rand(2, 4, device='meta') + data = torch.rand(2, 4, device="meta") # test layernorm ln = torch.nn.LayerNorm(4) @@ -76,67 +76,71 @@ def test_embedding(): # test batch norm 1d bn1d = torch.nn.BatchNorm1d(4) - data = torch.rand(2, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) - - data = torch.rand(2, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) - - data = torch.rand(2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) - - data = torch.rand(1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn1d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=True, - output_shape=None) + data = torch.rand(2, 4, device="meta") + _assert_output_shape( + data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) + + data = torch.rand(2, 4, device="meta") + _assert_output_shape( + data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) + + data = torch.rand(2, 3, 4, device="meta") + _assert_output_shape( + data=data, + module=bn1d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) + + data = torch.rand(1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, module=bn1d, patch_fn=patched_module.torch_nn_normalize, expect_exception=True, output_shape=None + ) # test batch norm 2d bn2d = torch.nn.BatchNorm2d(4) - data = torch.rand(1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn2d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) + data = torch.rand(1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, + module=bn2d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) - data = torch.rand(2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn2d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=True, - output_shape=None) + data = torch.rand(2, 3, 4, device="meta") + _assert_output_shape( + data=data, module=bn2d, patch_fn=patched_module.torch_nn_normalize, expect_exception=True, output_shape=None + ) # # test batch size 3d bn3d = torch.nn.BatchNorm3d(4) - data = torch.rand(1, 1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn3d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=False, - output_shape=data.shape) + data = torch.rand(1, 1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, + module=bn3d, + patch_fn=patched_module.torch_nn_normalize, + expect_exception=False, + output_shape=data.shape, + ) - data = torch.rand(1, 2, 3, 4, device='meta') - _assert_output_shape(data=data, - module=bn3d, - patch_fn=patched_module.torch_nn_normalize, - expect_exception=True, - output_shape=None) + data = torch.rand(1, 2, 3, 4, device="meta") + _assert_output_shape( + data=data, module=bn3d, patch_fn=patched_module.torch_nn_normalize, expect_exception=True, output_shape=None + ) @clear_cache_before_run() @@ -146,35 +150,38 @@ def test_conv1d(): conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=conv1d, - patch_fn=patched_module.torch_nn_conv1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv1d = torch.nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = conv1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=conv1d, - patch_fn=patched_module.torch_nn_conv1d, - expect_exception=False, - output_shape=materialized_output.shape) - - conv1d = torch.nn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=2, - padding=1, - dilation=2, - padding_mode='reflect') + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) + + conv1d = torch.nn.Conv1d( + in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect" + ) materialized_output = conv1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=conv1d, - patch_fn=patched_module.torch_nn_conv1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=conv1d, + patch_fn=patched_module.torch_nn_conv1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) def test_conv2d(): @@ -182,40 +189,45 @@ def test_conv2d(): data = torch.rand(2, 3, 4, 4) conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv2d = torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) - - conv2d = torch.nn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=2, - padding=1, - dilation=2, - padding_mode='reflect') + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) + + conv2d = torch.nn.Conv2d( + in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect" + ) materialized_output = conv2d(data) - _assert_output_shape(data=data, - module=conv2d, - patch_fn=patched_module.torch_nn_conv2d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv2d, + patch_fn=patched_module.torch_nn_conv2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -224,40 +236,45 @@ def test_conv3d(): data = torch.rand(2, 3, 4, 4, 4) conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) conv3d = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) - - conv3d = torch.nn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=2, - padding=1, - dilation=2, - padding_mode='reflect') + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) + + conv3d = torch.nn.Conv3d( + in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2, padding_mode="reflect" + ) materialized_output = conv3d(data) - _assert_output_shape(data=data, - module=conv3d, - patch_fn=patched_module.torch_nn_conv3d, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=conv3d, + patch_fn=patched_module.torch_nn_conv3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -267,21 +284,25 @@ def test_conv_transpose1d(): convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = convtrans1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans1d, - patch_fn=patched_module.torch_nn_convtranspose1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) convtrans1d = torch.nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = convtrans1d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans1d, - patch_fn=patched_module.torch_nn_convtranspose1d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans1d, + patch_fn=patched_module.torch_nn_convtranspose1d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -291,21 +312,25 @@ def test_conv_transpose2d(): convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = convtrans2d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans2d, - patch_fn=patched_module.torch_nn_convtranspose2d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) convtrans2d = torch.nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = convtrans2d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans2d, - patch_fn=patched_module.torch_nn_convtranspose2d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans2d, + patch_fn=patched_module.torch_nn_convtranspose2d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() @@ -315,46 +340,56 @@ def test_conv_transpose3d(): convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2) materialized_output = convtrans3d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans3d, - patch_fn=patched_module.torch_nn_convtranspose3d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) convtrans3d = torch.nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1) materialized_output = convtrans3d(data) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - module=convtrans3d, - patch_fn=patched_module.torch_nn_convtranspose3d, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, + module=convtrans3d, + patch_fn=patched_module.torch_nn_convtranspose3d, + expect_exception=False, + output_shape=materialized_output.shape, + ) @clear_cache_before_run() def test_pool1d(): - combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], - [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]] + combinations = [ + [torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], + [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d], + ] - for (layer_cls, patch_func) in combinations: + for layer_cls, patch_func in combinations: pooler = layer_cls(kernel_size=3) data = torch.rand(2, 3, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) data = torch.rand(2, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) data = torch.rand(2, 3, 4, 4) _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) @@ -362,29 +397,35 @@ def test_pool1d(): @clear_cache_before_run() def test_pool2d(): - combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d], - [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]] + combinations = [ + [torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d], + [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d], + ] - for (layer_cls, patch_func) in combinations: + for layer_cls, patch_func in combinations: pooler = layer_cls(kernel_size=3) # test max pool 3d data = torch.rand(2, 3, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 3, 4, 4, 4) @@ -393,29 +434,35 @@ def test_pool2d(): @clear_cache_before_run() def test_pool3d(): - combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d], - [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]] + combinations = [ + [torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d], + [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d], + ] - for (layer_cls, patch_func) in combinations: + for layer_cls, patch_func in combinations: pooler = layer_cls(kernel_size=3) # test max pool 3d data = torch.rand(2, 3, 4, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 4, 4, 4) materialized_output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=materialized_output.shape) + _assert_output_shape( + data=data, + module=pooler, + patch_fn=patch_func, + expect_exception=False, + output_shape=materialized_output.shape, + ) # test max pool 3d data = torch.rand(2, 3, 4) @@ -430,19 +477,15 @@ def test_adaptive_pooling_1d(): data = torch.rand(3, 4) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4, 5) _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) @@ -458,19 +501,15 @@ def test_adaptive_pooling_2d(): data = torch.rand(2, 3, 4) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4, 5) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) @clear_cache_before_run() @@ -483,16 +522,12 @@ def test_adaptive_pooling_3d(): data = torch.rand(2, 3, 4, 5) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) data = torch.rand(2, 3, 4, 5, 6) output = pooler(data) - _assert_output_shape(data=data, - module=pooler, - patch_fn=patch_func, - expect_exception=False, - output_shape=output.shape) + _assert_output_shape( + data=data, module=pooler, patch_fn=patch_func, expect_exception=False, output_shape=output.shape + ) diff --git a/tests/test_fx/test_tracer/test_patched_op.py b/tests/test_fx/test_tracer/test_patched_op.py index e0c5f560c49e..37c2333c0982 100644 --- a/tests/test_fx/test_tracer/test_patched_op.py +++ b/tests/test_fx/test_tracer/test_patched_op.py @@ -33,38 +33,34 @@ def test_repeat_interleave(): data = torch.tensor([1, 2, 3]) materialized_output = torch.repeat_interleave(data, repeats=2) repeat_interleave = partial(patch_fn, repeats=2) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape + ) data = torch.tensor([[1, 2], [3, 4]]) materialized_output = torch.repeat_interleave(data, repeats=3, dim=1) repeat_interleave = partial(patch_fn, repeats=3, dim=1) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape + ) data = torch.tensor([[1, 2], [3, 4]]) materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=-1) repeat_interleave = partial(patch_fn, repeats=torch.tensor([1, 2]), dim=-1) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=False, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=False, output_shape=materialized_output.shape + ) data = torch.tensor([[1, 2], [3, 4]]) materialized_output = torch.repeat_interleave(data, repeats=torch.tensor([1, 2]), dim=0) repeat_interleave = partial(patch_fn, repeats=[1, 2], dim=0) - meta_data = data.to('meta') - _assert_output_shape(data=meta_data, - patch_fn=repeat_interleave, - expect_exception=True, - output_shape=materialized_output.shape) + meta_data = data.to("meta") + _assert_output_shape( + data=meta_data, patch_fn=repeat_interleave, expect_exception=True, output_shape=materialized_output.shape + ) @clear_cache_before_run() diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 98433b8f7c3b..2b3f3e039baf 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -20,7 +20,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): # 1. ConViT # 2. NormFreeNet # as they are not supported, let's skip them - if model.__class__.__name__ in ['ConViT', 'NormFreeNet']: + if model.__class__.__name__ in ["ConViT", "NormFreeNet"]: return gm = symbolic_trace(model, meta_args=meta_args) @@ -39,8 +39,9 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): for key in transformed_fx_out.keys(): fx_output_val = transformed_fx_out[key] non_fx_output_val = transformed_non_fx_out[key] - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" # FIXME(ver217): timm/models/convit.py:71: in forward @@ -49,22 +50,22 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): # return self.tracer.to_bool(self) # torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow @pytest.mark.skip("convit is not supported yet") -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("1.12.0"), reason="torch version < 12") @clear_cache_before_run() def test_timm_models(): torch.backends.cudnn.deterministic = True - sub_model_zoo = model_zoo.get_sub_registry('timm') + sub_model_zoo = model_zoo.get_sub_registry("timm") for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: - meta_args = {k: v.to('meta') for k, v in data.items()} + meta_args = {k: v.to("meta") for k, v in data.items()} else: meta_args = None trace_and_compare(model_fn, data, output_transform_fn, meta_args) -if __name__ == '__main__': +if __name__ == "__main__": test_timm_models() diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py index 2b7def5bef85..dd94a2546955 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -1,6 +1,5 @@ import pytest import torch -from packaging import version from torchaudio_utils import trace_and_compare from colossalai.testing import clear_cache_before_run @@ -14,11 +13,10 @@ def test_torchaudio_models(): torch.backends.cudnn.deterministic = True - sub_model_zoo = model_zoo.get_sub_registry('torchaudio') + sub_model_zoo = model_zoo.get_sub_registry("torchaudio") for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): model = model_fn() - trace_and_compare(model, - data_gen_fn, - output_transform_fn, - need_meta=(attribute is not None and attribute.has_control_flow)) + trace_and_compare( + model, data_gen_fn, output_transform_fn, need_meta=(attribute is not None and attribute.has_control_flow) + ) diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py index 239f38680cec..2379372bc3f9 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py @@ -6,7 +6,7 @@ def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False): data = data_gen() concrete_args = data if need_concrete else {} - meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {} + meta_args = {k: v.to("meta") for k, v in data.items()} if need_meta else {} model.eval() @@ -24,5 +24,6 @@ def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, nee for key, fx_output_val in transformed_fx_out.items(): non_fx_output_val = transformed_non_fx_out[key] - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index f969c8e6c3da..30c1910855e6 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -1,4 +1,3 @@ -import pytest import torch from colossalai._analyzer.fx import symbolic_trace @@ -32,31 +31,34 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): assert len(transformed_fx_out) == len(transformed_non_fx_out) if torch.is_tensor(fx_out): assert torch.allclose( - fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out, non_fx_out + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" else: assert torch.allclose( - fx_out.values(), - non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out.values(), non_fx_out.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" for key in transformed_fx_out.keys(): fx_output_val = transformed_fx_out[key] non_fx_output_val = transformed_non_fx_out[key] if torch.is_tensor(fx_output_val): - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" else: - assert torch.allclose(fx_output_val.values(), non_fx_output_val.values() - ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + assert torch.allclose( + fx_output_val.values(), non_fx_output_val.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" @clear_cache_before_run() def test_torchrec_deepfm_models(): - deepfm_models = model_zoo.get_sub_registry('deepfm') + deepfm_models = model_zoo.get_sub_registry("deepfm") torch.backends.cudnn.deterministic = True for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: - meta_args = {k: v.to('meta') for k, v in data.items()} + meta_args = {k: v.to("meta") for k, v in data.items()} else: meta_args = None diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 94fb24f33376..71b73236474f 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -1,4 +1,3 @@ -import pytest import torch from colossalai._analyzer.fx import symbolic_trace @@ -32,37 +31,40 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): assert len(transformed_fx_out) == len(transformed_non_fx_out) if torch.is_tensor(fx_out): assert torch.allclose( - fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out, non_fx_out + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" else: assert torch.allclose( - fx_out.values(), - non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + fx_out.values(), non_fx_out.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" for key in transformed_fx_out.keys(): fx_output_val = transformed_fx_out[key] non_fx_output_val = transformed_non_fx_out[key] if torch.is_tensor(fx_output_val): - assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \ - f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}' + assert torch.allclose( + fx_output_val, non_fx_output_val, atol=1e-5 + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}" else: - assert torch.allclose(fx_output_val.values(), non_fx_output_val.values() - ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' + assert torch.allclose( + fx_output_val.values(), non_fx_output_val.values() + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}" @clear_cache_before_run() def test_torchrec_dlrm_models(): torch.backends.cudnn.deterministic = True - dlrm_models = model_zoo.get_sub_registry('dlrm') + dlrm_models = model_zoo.get_sub_registry("dlrm") for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items(): data = data_gen_fn() # dlrm_interactionarch is not supported # TODO(FrankLeeeee): support this model - if name == 'dlrm_interactionarch': + if name == "dlrm_interactionarch": continue if attribute is not None and attribute.has_control_flow: - meta_args = {k: v.to('meta') for k, v in data.items()} + meta_args = {k: v.to("meta") for k, v in data.items()} else: meta_args = None diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index 74cb753e2937..47c6b1186c8e 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -8,7 +8,7 @@ @clear_cache_before_run() def test_torchvision_models(): torch.backends.cudnn.deterministic = True - tv_sub_registry = model_zoo.get_sub_registry('torchvision') + tv_sub_registry = model_zoo.get_sub_registry("torchvision") for name, (model_fn, data_gen_fn, output_transform_fn, _, model_attribute) in tv_sub_registry.items(): data = data_gen_fn() @@ -36,11 +36,11 @@ def test_torchvision_models(): fx_val = transformed_out[key] non_fx_val = transformed_non_fx_out[key] assert torch.allclose( - fx_val, - non_fx_val), f'{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}' + fx_val, non_fx_val + ), f"{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}" except Exception as e: print(name, e) -if __name__ == '__main__': +if __name__ == "__main__": test_torchvision_models() diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py index 3d56cc3484a6..2ddc8b6e68e4 100644 --- a/tests/test_infer/_utils.py +++ b/tests/test_infer/_utils.py @@ -1,20 +1,6 @@ import copy -import torch -import torch.distributed as dist -from torch import Tensor -from torch import distributed as dist -from torch.distributed import ProcessGroup -from torch.nn import Module -from torch.optim import Adam, Optimizer - -from colossalai.booster import Booster -from colossalai.booster.plugin import HybridParallelPlugin -from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer._utils import getattr_ -from colossalai.shardformer.policies.auto_policy import Policy -from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor def build_model( @@ -28,11 +14,13 @@ def build_model( org_model = model_fn() # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused, - inference_only=True) + shard_config = ShardConfig( + enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused, + inference_only=True, + ) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index 8ecabf69ecf3..5a5d341fc6ba 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -1,5 +1,3 @@ -import os - import pytest import torch from packaging import version @@ -16,22 +14,27 @@ MAX_INPUT_LEN = 16 MAX_OUTPUT_LEN = 32 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -@parameterize('test_config', [{ - 'tp_size': TP_SIZE, -}]) +@parameterize( + "test_config", + [ + { + "tp_size": TP_SIZE, + } + ], +) def run(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom_for_causal_lm') + sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom_for_causal_lm") for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): orig_model = model_fn() orig_model = orig_model.half() data = data_gen_fn() - shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, - inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) generate_kwargs = dict(do_sample=False) @@ -42,7 +45,7 @@ def run(test_config): def check_bloom(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run() @@ -54,5 +57,5 @@ def test_bloom_infer(): spawn(check_bloom, TP_SIZE) -if __name__ == '__main__': +if __name__ == "__main__": test_bloom_infer() diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index cc3cdd2b501b..f24160820e71 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -2,14 +2,12 @@ import pytest import torch -import torch.nn as nn from packaging import version -from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM +from transformers import BloomConfig, BloomForCausalLM from transformers.tokenization_utils_base import BatchEncoding import colossalai from colossalai.inference.tensor_parallel import TPInferEngine -from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -19,12 +17,17 @@ MAX_INPUT_LEN = 16 MAX_OUTPUT_LEN = 8 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") -@parameterize('test_config', [{ - 'tp_size': TP_SIZE, -}]) +@parameterize( + "test_config", + [ + { + "tp_size": TP_SIZE, + } + ], +) def run(test_config): model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4) model = BloomForCausalLM(model_config) @@ -32,8 +35,9 @@ def run(test_config): model.to(torch.cuda.current_device()) # 1. check TPInferEngine init and model optimization - shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, - inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) assert infer_engine.cache_manager is not None @@ -41,13 +45,17 @@ def run(test_config): assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE # 2. check data preparation - input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970], - [80540, 15473, 3331, 11970], [80540, 15473]] + input_ids_list = [ + [80540, 15473, 3331, 11970, 90472, 361, 61335], + [80540, 15473, 3331, 11970], + [80540, 15473, 3331, 11970], + [80540, 15473], + ] batch_size = len(input_ids_list) max_seq_len = max(len(li) for li in input_ids_list) attention_mask = [[0] * max_seq_len for _ in range(batch_size)] for i, li in enumerate(input_ids_list): - attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))] + attention_mask[i][max_seq_len - len(li) :] = [1 for _ in range(len(li))] data = dict(input_ids=input_ids_list, attention_mask=attention_mask) inputs_batch_encoding = BatchEncoding(data=data) seq_lengths = [len(li) for li in input_ids_list] @@ -78,7 +86,7 @@ def run(test_config): def check_engine(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run() @@ -90,5 +98,5 @@ def test_engine(): spawn(check_engine, TP_SIZE) -if __name__ == '__main__': +if __name__ == "__main__": test_engine() diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py index f57c6956f817..f3e2cdf1e18f 100644 --- a/tests/test_infer/test_kvcache_manager.py +++ b/tests/test_infer/test_kvcache_manager.py @@ -1,7 +1,8 @@ import os -from packaging import version + import pytest import torch +from packaging import version from colossalai.inference.tensor_parallel import MemoryManager from colossalai.logging import disable_existing_loggers @@ -14,14 +15,15 @@ HEAD_NUM = 32 HEAD_DIM = 128 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = str(port) + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) disable_existing_loggers() size = batch_size * (input_len + output_len) @@ -41,21 +43,24 @@ def create_cache_manager(rank, world_size, port, batch_size, input_len, output_l assert torch.equal(prefill_locs, prefill_locs_contiguous) assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill kvcache_manager.alloc_contiguous(batch_size) - assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False) + assert torch.all(kvcache_manager.mem_state[: total_token_prefill + batch_size] == False) + @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() def test_cache_manager_dist(): - spawn(create_cache_manager, - 4, - batch_size=BATCH_SIZE, - input_len=INPUT_LEN, - output_len=OUTPUT_LEN, - layer_num=LAYER_NUM, - head_num=HEAD_NUM, - head_dim=HEAD_DIM) + spawn( + create_cache_manager, + 4, + batch_size=BATCH_SIZE, + input_len=INPUT_LEN, + output_len=OUTPUT_LEN, + layer_num=LAYER_NUM, + head_num=HEAD_NUM, + head_dim=HEAD_DIM, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_cache_manager_dist() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index aa8874ea4cb0..0e5efe68508a 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -1,5 +1,4 @@ import os -import warnings import pytest import torch @@ -12,13 +11,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 2 BATCH_SIZE = 8 MAX_INPUT_LEN = 12 MAX_OUTPUT_LEN = 100 -CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5') +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") def init_to_get_rotary(self, base=10000): @@ -34,8 +33,9 @@ def init_to_get_rotary(self, base=10000): else: max_seq_len = 2048 * rope_scaling_factor base = float(base) - inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / - self.config.head_dim_)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_) + ) t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) @@ -44,20 +44,25 @@ def init_to_get_rotary(self, base=10000): return -@parameterize('test_config', [{ - 'tp_size': TPSIZE, -}]) +@parameterize( + "test_config", + [ + { + "tp_size": TPSIZE, + } + ], +) def run_llama_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama_for_casual_lm') + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm") for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): orig_model = model_fn() init_to_get_rotary(orig_model.model, base=10000) orig_model = orig_model.half() data = data_gen_fn() - shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False, - inference_only=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) generate_kwargs = dict(do_sample=False) @@ -68,7 +73,7 @@ def run_llama_test(test_config): def check_llama(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_test() diff --git a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py index cb12faf6276c..a4d893f8e830 100644 --- a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py +++ b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py @@ -1,16 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import os import pytest -import numpy as np -from packaging import version - import torch from torch import nn -from torch.nn import functional as F -try: +try: from vllm import layernorm_ops + rms_norm = layernorm_ops.rms_norm HAS_VLLM_KERNERL = True except: @@ -18,6 +14,7 @@ print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") HAS_VLLM_KERNERL = False + class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -34,6 +31,7 @@ def forward(self, hidden_states): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) + def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): x = hidden_states out = torch.empty_like(x) @@ -45,6 +43,7 @@ def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): ) return out + @pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") def test_rmsnorm(): data = torch.randn((1024, 64), dtype=torch.float16, device="cuda") @@ -56,5 +55,6 @@ def test_rmsnorm(): check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5) assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" + if __name__ == "__main__": - test_rmsnorm() \ No newline at end of file + test_rmsnorm() diff --git a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py index 2a85566c65c6..40451ef6636d 100644 --- a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py +++ b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py @@ -1,8 +1,8 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import pytest from typing import Tuple +import pytest import torch import torch.nn as nn import torch.nn.functional as F @@ -10,17 +10,18 @@ try: from vllm import pos_encoding_ops + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox HAS_VLLM_KERNERL = True -except: +except: print("fall back to original rotary_embedding_neox of huggingface") print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") HAS_VLLM_KERNERL = False def rotate_half(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -49,7 +50,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings # Create cos and sin embeddings. - inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) t = torch.arange(max_position_embeddings).float() freqs = torch.einsum("i,j->ij", t, inv_freq.float()) emb = torch.cat((freqs, freqs), dim=-1) @@ -64,11 +65,10 @@ def forward( query: torch.Tensor, # [num_tokens, num_heads, head_size] key: torch.Tensor, # [num_tokens, num_heads, head_size] ) -> Tuple[torch.Tensor, torch.Tensor]: - - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] query_rot = query_rot.transpose(0, 1) key_rot = key_rot.transpose(0, 1) @@ -84,6 +84,7 @@ def forward( # Output query/key shape: [num_tokens, num_tokens, head_size] return query, key + def run_rotary_embedding_neox( num_tokens: int, num_heads: int, @@ -93,24 +94,18 @@ def run_rotary_embedding_neox( dtype: torch.dtype, base: int = 10000, ) -> None: - positions = torch.randint(0, max_position, (num_tokens, ), device='cuda') - query = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device='cuda') - key = torch.randn(num_tokens, - num_heads * head_size, - dtype=dtype, - device='cuda') + positions = torch.randint(0, max_position, (num_tokens,), device="cuda") + query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") + key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") # Create the rotary embedding. - inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim)) t = torch.arange(max_position).float() - freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) + freqs = torch.einsum("i,j -> ij", t, inv_freq.float()) cos = freqs.cos() sin = freqs.sin() cos_sin_cache = torch.cat((cos, sin), dim=-1) - cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") # Run the kernel. The kernel is in-place, so we need to clone the inputs. out_query = query.clone() @@ -128,7 +123,7 @@ def run_rotary_embedding_neox( dim=rotary_dim, max_position_embeddings=max_position, base=base, - ).to(dtype=dtype, device='cuda') + ).to(dtype=dtype, device="cuda") ref_query, ref_key = ref_rotary_embedding( positions, query.view(num_tokens, num_heads, head_size), @@ -141,6 +136,7 @@ def run_rotary_embedding_neox( assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) + @pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") def test_rotary_embedding(): run_rotary_embedding_neox( @@ -149,8 +145,9 @@ def test_rotary_embedding(): head_size=64, max_position=8192, rotary_dim=64, - dtype=torch.float16, + dtype=torch.float16, ) + if __name__ == "__main__": - test_rotary_embedding() \ No newline at end of file + test_rotary_embedding() diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py index b081b32b9ad3..0732ace1e04b 100644 --- a/tests/test_infer_ops/triton/kernel_utils.py +++ b/tests/test_infer_ops/triton/kernel_utils.py @@ -1,19 +1,18 @@ import math -import numpy as np import torch from torch.nn import functional as F def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): - ''' - adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 - ''' + """ + adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + """ xq = xq.view(bs, seqlen, num_head, head_dim) xk = xk.view(bs, seqlen, num_head, head_dim) xv = xv.view(bs, seqlen, num_head, head_dim) mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.] = -100000000.0 + mask[mask == 0.0] = -100000000.0 mask = mask.repeat(bs, num_head, 1, 1) keys = xk values = xv diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py index 344ad078e2e2..7a6c218a6691 100644 --- a/tests/test_infer_ops/triton/test_bloom_context_attention.py +++ b/tests/test_infer_ops/triton/test_bloom_context_attention.py @@ -1,27 +1,24 @@ -import math - import pytest import torch from packaging import version -from torch import nn -from torch.nn import functional as F try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton import bloom_context_attn_fwd from tests.test_infer_ops.triton.kernel_utils import torch_context_attention + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_bloom_context_attention(): bs = 4 head_num = 8 @@ -46,8 +43,9 @@ def test_bloom_context_attention(): torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, - atol=1e-2), "outputs from triton and torch are not matched" + assert torch.allclose( + torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2 + ), "outputs from triton and torch are not matched" if __name__ == "__main__": diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py index c656f81d2790..34e453f7840e 100644 --- a/tests/test_infer_ops/triton/test_copy_kv_dest.py +++ b/tests/test_infer_ops/triton/test_copy_kv_dest.py @@ -1,25 +1,24 @@ import pytest import torch from packaging import version -from torch import nn try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_kv_cache_copy_op(): - B_NTX = 32 * 2048 head_num = 8 head_dim = 64 @@ -31,8 +30,9 @@ def test_kv_cache_copy_op(): copy_kv_cache_to_dest(cache, dest_index, dest_data) - assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, - atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched" + assert torch.allclose( + cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3 + ), "copy_kv_cache_to_dest outputs from triton and torch are not matched" if __name__ == "__main__": diff --git a/tests/test_infer_ops/triton/test_layernorm_triton.py b/tests/test_infer_ops/triton/test_layernorm_triton.py index 94cd704ffeba..7f814e8c9a9f 100644 --- a/tests/test_infer_ops/triton/test_layernorm_triton.py +++ b/tests/test_infer_ops/triton/test_layernorm_triton.py @@ -6,30 +6,29 @@ from colossalai.testing.utils import parameterize try: - import triton - import triton.language as tl + pass - from colossalai.kernel.triton.fused_layernorm import _layer_norm_fwd_fused HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") -@parameterize('M', [2, 4, 8, 16]) -@parameterize('N', [64, 128]) +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +@parameterize("M", [2, 4, 8, 16]) +@parameterize("N", [64, 128]) def test_layer_norm(M, N): dtype = torch.float16 eps = 1e-5 x_shape = (M, N) w_shape = (x_shape[-1],) - weight = torch.rand(w_shape, dtype=dtype, device='cuda') - bias = torch.rand(w_shape, dtype=dtype, device='cuda') - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + bias = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") y_triton = layer_norm(x, weight, bias, eps) y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py index 4ea6095d4109..be6de6db2471 100644 --- a/tests/test_infer_ops/triton/test_llama_context_attention.py +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -1,27 +1,24 @@ -import math - import pytest import torch from packaging import version -from torch import nn -from torch.nn import functional as F try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton import llama_context_attn_fwd from tests.test_infer_ops.triton.kernel_utils import torch_context_attention + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_llama_context_attention(): bs = 4 head_num = 8 @@ -45,8 +42,9 @@ def test_llama_context_attention(): torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) - assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, - atol=1e-3), "outputs from triton and torch are not matched" + assert torch.allclose( + torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3 + ), "outputs from triton and torch are not matched" if __name__ == "__main__": diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py index d5ecdf684538..7e05ccafbfc4 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -1,14 +1,12 @@ # Adapted from ModelTC https://github.com/ModelTC/lightllm -import time import pytest import torch from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd @@ -17,13 +15,13 @@ HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") def torch_rotary_emb(x, cos, sin): seq_len, h, dim = x.shape - x0 = x[:, :, 0:dim // 2] - x1 = x[:, :, dim // 2:dim] + x0 = x[:, :, 0 : dim // 2] + x1 = x[:, :, dim // 2 : dim] cos = cos.view((seq_len, 1, dim // 2)) sin = sin.view((seq_len, 1, dim // 2)) o0 = x0 * cos - x1 * sin @@ -31,8 +29,9 @@ def torch_rotary_emb(x, cos, sin): return torch.cat((o0, o1), dim=-1) -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_rotary_emb(): SEQ_LEN = 1 HEAD_NUM = 32 @@ -40,10 +39,10 @@ def test_rotary_emb(): dtype = torch.half # create data x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") cos_shape = (SEQ_LEN, HEAD_DIM // 2) - cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda') - sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda') + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") # forward pass y_torch = torch_rotary_emb(x, cos, sin) rotary_embedding_fwd(x, cos, sin) diff --git a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py index 9692737a05a0..9bdec86645b2 100644 --- a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py +++ b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py @@ -1,24 +1,27 @@ import pytest -from packaging import version import torch -from torch import nn import torch.nn.functional as F +from packaging import version try: import triton - import triton.language as tl - from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton + from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel + from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_qkv_matmul(): - qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) + qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16) scale = 1.2 head_size = 32 batches = qkv.shape[0] @@ -26,7 +29,7 @@ def test_qkv_matmul(): num_of_heads = d_model // head_size q = qkv[:, :, :d_model] - k = qkv[:, :, d_model:d_model * 2] + k = qkv[:, :, d_model : d_model * 2] q = q.view(batches, -1, num_of_heads, head_size) k = k.view(batches, -1, num_of_heads, head_size) @@ -36,29 +39,40 @@ def test_qkv_matmul(): k = torch.transpose(k, 1, 2).contiguous() k = torch.transpose(k, 2, 3).contiguous() - torch_ouput = torch.einsum('bnij,bnjk->bnik', q, k) + torch_ouput = torch.einsum("bnij,bnjk->bnik", q, k) torch_ouput *= 1.2 q, k = q_copy, k_copy batches, M, H, K = q.shape N = k.shape[1] - score_output = torch.empty( - (batches, H, M, N), device=q.device, dtype=q.dtype) + score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype) grid = lambda meta: ( batches, H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * - triton.cdiv(N, meta["BLOCK_SIZE_N"]), + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), ) K = q.shape[3] qkv_gemm_4d_kernel[grid]( - q, k, score_output, - M, N, K, - q.stride(0), q.stride(2), q.stride(1), q.stride(3), - k.stride(0), k.stride(2), k.stride(3), k.stride(1), - score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + q, + k, + score_output, + M, + N, + K, + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + k.stride(0), + k.stride(2), + k.stride(3), + k.stride(1), + score_output.stride(0), + score_output.stride(1), + score_output.stride(2), + score_output.stride(3), scale=scale, # currently manually setting, later on we can use auto-tune config to match best setting BLOCK_SIZE_M=64, @@ -69,21 +83,16 @@ def test_qkv_matmul(): check = torch.allclose(torch_ouput.cpu(), score_output.cpu(), rtol=1e-3, atol=1e-5) assert check is True, "the outputs of triton and torch are not matched" - -def self_attention_compute_using_torch(qkv, - input_mask, - scale, - head_size - ): +def self_attention_compute_using_torch(qkv, input_mask, scale, head_size): batches = qkv.shape[0] d_model = qkv.shape[-1] // 3 num_of_heads = d_model // head_size - + q = qkv[:, :, :d_model] - k = qkv[:, :, d_model:d_model * 2] - v = qkv[:, :, d_model * 2:] + k = qkv[:, :, d_model : d_model * 2] + v = qkv[:, :, d_model * 2 :] q = q.view(batches, -1, num_of_heads, head_size) k = k.view(batches, -1, num_of_heads, head_size) v = v.view(batches, -1, num_of_heads, head_size) @@ -94,37 +103,36 @@ def self_attention_compute_using_torch(qkv, k = torch.transpose(k, -1, -2).contiguous() - score_output = torch.einsum('bnij,bnjk->bnik', q, k) + score_output = torch.einsum("bnij,bnjk->bnik", q, k) score_output *= scale - softmax_output = F.softmax(score_output, dim = -1) - res = torch.einsum('bnij,bnjk->bnik', softmax_output, v) + softmax_output = F.softmax(score_output, dim=-1) + res = torch.einsum("bnij,bnjk->bnik", softmax_output, v) res = torch.transpose(res, 1, 2) res = res.contiguous() - return res.view(batches, -1, d_model), score_output, softmax_output -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") -def test_self_atttention_test(): - qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_self_atttention_test(): + qkv = torch.randn((4, 24, 64 * 3), device="cuda", dtype=torch.float16) data_output_torch, score_output_torch, softmax_output_torch = self_attention_compute_using_torch( - qkv.clone(), - input_mask = None, - scale = 1.2, - head_size = 32 - ) + qkv.clone(), input_mask=None, scale=1.2, head_size=32 + ) data_output_triton = self_attention_compute_using_triton( - qkv.clone(), - alibi=None, - head_size=32, - scale=1.2, - input_mask=None, - layer_past=None, - use_flash=False, - triangular=True) + qkv.clone(), + alibi=None, + head_size=32, + scale=1.2, + input_mask=None, + layer_past=None, + use_flash=False, + triangular=True, + ) check = torch.allclose(data_output_triton.cpu(), data_output_torch.cpu(), rtol=1e-4, atol=1e-2) assert check is True, "the triton output is not matched with torch output" @@ -132,4 +140,4 @@ def test_self_atttention_test(): if __name__ == "__main__": test_qkv_matmul() - test_self_atttention_test() \ No newline at end of file + test_self_atttention_test() diff --git a/tests/test_infer_ops/triton/test_softmax.py b/tests/test_infer_ops/triton/test_softmax.py index 6a244608c43f..43b9c0929c4a 100644 --- a/tests/test_infer_ops/triton/test_softmax.py +++ b/tests/test_infer_ops/triton/test_softmax.py @@ -1,30 +1,31 @@ import pytest -from packaging import version import torch +from packaging import version from torch import nn - try: - import triton - import triton.language as tl from colossalai.kernel.triton.softmax import softmax + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_softmax_op(): data_samples = [ - torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), - torch.randn((320, 320, 78), device = "cuda", dtype = torch.float32), - torch.randn((2345, 4, 5, 64), device = "cuda", dtype = torch.float16) - ] + torch.randn((3, 4, 5, 32), device="cuda", dtype=torch.float32), + torch.randn((320, 320, 78), device="cuda", dtype=torch.float32), + torch.randn((2345, 4, 5, 64), device="cuda", dtype=torch.float16), + ] for data in data_samples: - module = nn.Softmax(dim = -1) + module = nn.Softmax(dim=-1) data_torch_out = module(data) data_triton_out = softmax(data) check = torch.allclose(data_torch_out.cpu(), data_triton_out.cpu(), rtol=1e-3, atol=1e-3) @@ -32,4 +33,4 @@ def test_softmax_op(): if __name__ == "__main__": - test_softmax_op() \ No newline at end of file + test_softmax_op() diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py index aee7944597dc..fc5f8cd6c9dc 100644 --- a/tests/test_infer_ops/triton/test_token_attn_1.py +++ b/tests/test_infer_ops/triton/test_token_attn_1.py @@ -5,16 +5,16 @@ from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1 + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): @@ -23,8 +23,9 @@ def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): keys = xk xq = xq.transpose(1, 2) keys = keys.transpose(1, 2) - scores = (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape( - num_head, -1) + scores = ( + (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) + ) return scores @@ -37,10 +38,11 @@ def torch_attn_1(xq, xk, seqlen, num_head, head_dim): return logics -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_attn_1(): - import time + pass batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 diff --git a/tests/test_infer_ops/triton/test_token_attn_2.py b/tests/test_infer_ops/triton/test_token_attn_2.py index f834fedbb0f1..2dd756f2ba91 100644 --- a/tests/test_infer_ops/triton/test_token_attn_2.py +++ b/tests/test_infer_ops/triton/test_token_attn_2.py @@ -1,20 +1,18 @@ -import math - import pytest import torch from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2 + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") def torch_attn(V, P, bs, seqlen, num_head, head_dim): @@ -25,19 +23,23 @@ def torch_attn(V, P, bs, seqlen, num_head, head_dim): return attn_out -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_token_attn_2(): - import time + pass batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 dtype = torch.float16 V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) - Prob = torch.empty( - (head_num, batch_size * seq_len), dtype=dtype, - device="cuda").normal_(mean=0.4, std=0.2).reshape(head_num, batch_size, - seq_len).softmax(-1).reshape(head_num, batch_size * seq_len) + Prob = ( + torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") + .normal_(mean=0.4, std=0.2) + .reshape(head_num, batch_size, seq_len) + .softmax(-1) + .reshape(head_num, batch_size * seq_len) + ) attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda") kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py index e82318965e05..9c7a53798317 100644 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -1,20 +1,18 @@ -import time - import pytest import torch from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): @@ -29,10 +27,10 @@ def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): return torch.sum(prob * xv, dim=1, keepdim=False) -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test(): - Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128 dtype = torch.float16 q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) diff --git a/tests/test_infer_ops/triton/test_token_softmax.py b/tests/test_infer_ops/triton/test_token_softmax.py index 08ffe1ca8323..1f97f1674818 100644 --- a/tests/test_infer_ops/triton/test_token_softmax.py +++ b/tests/test_infer_ops/triton/test_token_softmax.py @@ -3,22 +3,22 @@ from packaging import version try: - import triton - import triton.language as tl + pass from colossalai.kernel.triton.token_attention_kernel import token_attn_softmax_fwd + HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, - reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) def test_softmax(): - import torch batch_size, seq_len, head_num, head_dim = 4, 1025, 12, 128 diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 9d9e9a3a5c76..ea6b16b94785 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -12,7 +12,7 @@ from colossalai.tensor.d_tensor.layout import Layout from tests.kit.model_zoo.registry import ModelAttribute -SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0') +SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse("1.12.0") # model_fn, data_gen_fn, output_transform_fn, model_attr TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[], dict], Optional[ModelAttribute]] @@ -28,18 +28,22 @@ def assert_model_equal(m1: torch.nn.Module, m2: torch.nn.Module) -> None: s1 = m1.state_dict() s2 = m2.state_dict() - assert len(s1) == len(s2), f'len {len(s1)} vs {len(s2)}' + assert len(s1) == len(s2), f"len {len(s1)} vs {len(s2)}" for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()): assert n1 == n2 - assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' + assert torch.equal(t1, t2), f"{n1} {t1} vs {t2}" for p1, p2 in zip(m1.parameters(), m2.parameters()): assert p1.requires_grad == p2.requires_grad -def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict], - output_transform_fn: Callable[[Any], dict]) -> None: +def assert_forward_equal( + m1: torch.nn.Module, + m2: torch.nn.Module, + data_gen_fn: Callable[[], dict], + output_transform_fn: Callable[[Any], dict], +) -> None: data = data_gen_fn() m1.eval() @@ -57,15 +61,14 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: for key, out1 in transformed_out1.items(): out2 = transformed_out2[key] - assert torch.allclose(out1, out2, atol=1e-5), \ - f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}' + assert torch.allclose( + out1, out2, atol=1e-5 + ), f"{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}" -def check_lazy_init(entry: TestingEntry, - seed: int = 42, - verbose: bool = False, - check_forward: bool = False, - default_device: str = 'cpu') -> None: +def check_lazy_init( + entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False, default_device: str = "cpu" +) -> None: model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry _MyTensor._pre_op_fn = lambda *args: set_seed(seed) LazyTensor._pre_op_fn = lambda *args: set_seed(seed) @@ -84,15 +87,16 @@ def check_lazy_init(entry: TestingEntry, assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn) assert_forward_equal(deferred_model, copied_deferred_model, data_gen_fn, output_transform_fn) if verbose: - print(f'{model.__class__.__name__} pass') + print(f"{model.__class__.__name__} pass") -def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, - sharding_spec_dict: dict) -> None: +def assert_dist_model_equal( + model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, sharding_spec_dict: dict +) -> None: state = model.state_dict() distributed_state = distributed_model.state_dict() - assert len(state) == len(distributed_state), f'len {len(state)} vs {len(distributed_state)}' + assert len(state) == len(distributed_state), f"len {len(state)} vs {len(distributed_state)}" for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()): assert n1 == n2 @@ -102,4 +106,4 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn. layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape) t2.dist_layout = layout t2 = to_global(t2) - assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' + assert torch.equal(t1, t2), f"{n1} {t1} vs {t2}" diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index 18a737fcec85..978cf06b55a0 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -4,19 +4,21 @@ from tests.kit.model_zoo import model_zoo -@pytest.mark.skipif(not SUPPORT_LAZY, reason='requires torch >= 1.12.0') -@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) -@pytest.mark.parametrize('default_device', ['cpu', 'cuda']) +@pytest.mark.skipif(not SUPPORT_LAZY, reason="requires torch >= 1.12.0") +@pytest.mark.parametrize("subset", ["torchvision", "diffusers", "timm", "transformers", "torchaudio", "deepfm", "dlrm"]) +@pytest.mark.parametrize("default_device", ["cpu", "cuda"]) def test_torchvision_models_lazy_init(subset, default_device): sub_model_zoo = model_zoo.get_sub_registry(subset) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', - 'torchaudio_hubert_base') or name.startswith('transformers_llama') or name.startswith( - ('transformers_vit', 'transformers_blip2')): + if ( + name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") + or name.startswith("transformers_llama") + or name.startswith(("transformers_vit", "transformers_blip2")) + ): continue check_lazy_init(entry, verbose=True, default_device=default_device) -if __name__ == '__main__': - test_torchvision_models_lazy_init('torchvision') +if __name__ == "__main__": + test_torchvision_models_lazy_init("torchvision") diff --git a/tests/test_legacy/test_amp/test_naive_fp16.py b/tests/test_legacy/test_amp/test_naive_fp16.py index 54bf6498549c..76f9ff07407f 100644 --- a/tests/test_legacy/test_amp/test_naive_fp16.py +++ b/tests/test_legacy/test_amp/test_naive_fp16.py @@ -13,7 +13,7 @@ def check_equal(a, b): """ This function checks if two tensors are equal within tolerance """ - assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}' + assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f"a = {a}, b = {b}" def run_naive_amp(): @@ -25,7 +25,7 @@ def run_naive_amp(): torch.backends.cudnn.deterministic = True # create layer - test_models = ['repeated_computed_layers', 'nested_model', 'resnet18'] + test_models = ["repeated_computed_layers", "nested_model", "resnet18"] for test_name in test_models: get_component_func = non_distributed_component_funcs.get_callable(test_name) model_builder, train_dataloader, _, optim_class, _ = get_component_func() @@ -41,9 +41,10 @@ def run_naive_amp(): # inject naive and apex amp naive_amp_config = dict(initial_scale=128, clip_grad_norm=1.0) - naive_amp_model, naive_amp_optimizer = convert_to_naive_amp(naive_amp_model, naive_amp_optimizer, - naive_amp_config) - apex_amp_config = dict(opt_level='O2', loss_scale=128, keep_batchnorm_fp32=False) + naive_amp_model, naive_amp_optimizer = convert_to_naive_amp( + naive_amp_model, naive_amp_optimizer, naive_amp_config + ) + apex_amp_config = dict(opt_level="O2", loss_scale=128, keep_batchnorm_fp32=False) apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) # create data @@ -78,7 +79,7 @@ def run_naive_amp(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") run_naive_amp() @@ -89,5 +90,5 @@ def test_naive_amp(): spawn(run_dist, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_naive_amp() diff --git a/tests/test_legacy/test_amp/test_torch_fp16.py b/tests/test_legacy/test_amp/test_torch_fp16.py index 89810b5d0351..47b303745e4e 100644 --- a/tests/test_legacy/test_amp/test_torch_fp16.py +++ b/tests/test_legacy/test_amp/test_torch_fp16.py @@ -18,7 +18,7 @@ def run_torch_amp(): torch.backends.cudnn.deterministic = True # create layer - test_models = ['resnet18', 'simple_net'] + test_models = ["resnet18", "simple_net"] for test_name in test_models: get_component_func = non_distributed_component_funcs.get_callable(test_name) model_builder, train_dataloader, _, optim_class, _ = get_component_func() @@ -34,10 +34,10 @@ def run_torch_amp(): # inject torch and apex amp torch_amp_config = dict(init_scale=128, enabled=True) - torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp(torch_amp_model, - torch_amp_optimizer, - amp_config=torch_amp_config) - apex_amp_config = dict(opt_level='O1', loss_scale=128) + torch_amp_model, torch_amp_optimizer, _ = convert_to_torch_amp( + torch_amp_model, torch_amp_optimizer, amp_config=torch_amp_config + ) + apex_amp_config = dict(opt_level="O1", loss_scale=128) apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) # create data @@ -61,7 +61,7 @@ def run_torch_amp(): # check grad # In apex amp, grad is not scaled before backward, but torch amp does for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): - assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config['loss_scale']) + assert_close_loose(torch_amp_param.grad, apex_amp_param.grad * apex_amp_config["loss_scale"]) # clip gradient apex_amp_optimizer.clip_grad_norm(model=apex_amp_model, max_norm=1.0) @@ -78,7 +78,7 @@ def run_torch_amp(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.legacy.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") run_torch_amp() @@ -89,5 +89,5 @@ def test_torch_amp(): spawn(run_dist, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_torch_amp() diff --git a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py index 4851b3e36bbc..bc243631a6c5 100644 --- a/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py +++ b/tests/test_legacy/test_comm/test_boardcast_send_recv_v2.py @@ -16,11 +16,15 @@ def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl', verbose=False) + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl", verbose=False) rank = gpc.get_local_rank(ParallelMode.PIPELINE) if rank == 0: - obj = [torch.randn(3,)] + obj = [ + torch.randn( + 3, + ) + ] _send_object(obj, 1) if rank == 1: @@ -30,7 +34,11 @@ def check_layer(rank, world_size, port): _recv_object(3) if rank == 3: - obj = [torch.randn(3,)] + obj = [ + torch.randn( + 3, + ) + ] _send_object(obj, 2) gpc.destroy() @@ -43,5 +51,5 @@ def test_object_list_p2p(): spawn(check_layer, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_object_list_p2p() diff --git a/tests/test_legacy/test_comm/test_comm.py b/tests/test_legacy/test_comm/test_comm.py index fccfcd973000..7d2c81972e5a 100644 --- a/tests/test_legacy/test_comm/test_comm.py +++ b/tests/test_legacy/test_comm/test_comm.py @@ -17,41 +17,41 @@ def check_all_gather(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) tensor = tensor.to(get_current_device()) - print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True) - print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) op.wait() - print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Complete: Rank {0} - {1}".format(dist.get_rank(), tensor)) torch.cuda.synchronize() def check_reduce_scatter(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) tensor = tensor.to(get_current_device()) - print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True) - print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) op.wait() - print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Complete: Rank {0} - {1}".format(dist.get_rank(), tensor)) torch.cuda.synchronize() def check_all_reduce(): tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)]) tensor = tensor.to(get_current_device()) - print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Before: Rank {0} - {1}".format(dist.get_rank(), tensor)) tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True) - print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("After: Rank {0} - {1}".format(dist.get_rank(), tensor)) op.wait() - print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor)) + print("Complete: Rank {0} - {1}".format(dist.get_rank(), tensor)) torch.cuda.synchronize() def check_layer(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") assert dist.get_rank() == gpc.get_global_rank() - print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size())) + print("Rank {} / {}".format(dist.get_rank(), dist.get_world_size())) check_all_gather() check_reduce_scatter() @@ -67,5 +67,5 @@ def test_comm(): spawn(check_layer, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_comm() diff --git a/tests/test_legacy/test_comm/test_object_list_p2p.py b/tests/test_legacy/test_comm/test_object_list_p2p.py index a1322e6f28db..69c68c7159e4 100644 --- a/tests/test_legacy/test_comm/test_object_list_p2p.py +++ b/tests/test_legacy/test_comm/test_object_list_p2p.py @@ -27,7 +27,7 @@ def check_send_recv_forward(): if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") data_to_send = data.to(device) data_list_to_send = [] for data_in_list in data_list: @@ -35,7 +35,7 @@ def check_send_recv_forward(): send_forward(data_to_send) send_forward(data_list_to_send) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") data_recv = recv_forward(TENSOR_SIZE) data_list_recv = recv_forward(TENSOR_SIZE_LIST) data_to_check = data.to(device) @@ -47,7 +47,7 @@ def check_send_recv_forward(): def check_send_recv_backward(): if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") grad_recv = recv_backward(TENSOR_SIZE) grad_list_recv = recv_backward(TENSOR_SIZE_LIST) grad_to_check = grad.to(device) @@ -56,7 +56,7 @@ def check_send_recv_backward(): grad_to_check = grad_send.to(device) assert grad_recv.equal(grad_to_check) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") grad_to_send = grad.to(device) grad_list_to_send = [] for grad_in_list in grad_list: @@ -67,7 +67,7 @@ def check_send_recv_backward(): def check_send_recv_forward_backward(): if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") data_list_to_send = [] for data_in_list in data_list: data_list_to_send.append(data_in_list.to(device)) @@ -77,7 +77,7 @@ def check_send_recv_forward_backward(): grad_to_check = grad_send.to(device) assert grad_recv.equal(grad_to_check) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") grad_list_to_send = [] for grad_in_list in grad_list: grad_list_to_send.append(grad_in_list.to(device)) @@ -88,7 +88,7 @@ def check_send_recv_forward_backward(): def check_layer(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_send_recv_forward() check_send_recv_backward() check_send_recv_forward_backward() @@ -102,5 +102,5 @@ def test_object_list_p2p(): spawn(check_layer, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_object_list_p2p() diff --git a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py index f805bd19d7e8..eb05ea4839c6 100644 --- a/tests/test_legacy/test_comm/test_object_list_p2p_v2.py +++ b/tests/test_legacy/test_comm/test_object_list_p2p_v2.py @@ -32,7 +32,7 @@ def check_send_recv_forward(): local_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if local_rank == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") data_to_send = data.to(device) data_list_to_send = [] for data_in_list in data_list: @@ -42,7 +42,7 @@ def check_send_recv_forward(): send_forward(data_list_to_send, scatter_gather_tensors=use_scatter_gather_tensors) elif local_rank == 1: - device = torch.device('cuda:1') + device = torch.device("cuda:1") data_recv = recv_forward(TENSOR_SIZE, scatter_gather_tensors=use_scatter_gather_tensors) data_list_recv = recv_forward(TENSOR_SIZE_LIST, scatter_gather_tensors=use_scatter_gather_tensors) @@ -60,7 +60,7 @@ def check_send_recv_forward(): def check_send_recv_backward(): disable_existing_loggers() if gpc.get_local_rank(ParallelMode.PIPELINE) == 0: - device = torch.device('cuda:0') + device = torch.device("cuda:0") grad_recv = recv_backward(TENSOR_SIZE) grad_list_recv = recv_backward(TENSOR_SIZE_LIST) @@ -73,7 +73,7 @@ def check_send_recv_backward(): grad_to_check = grad_send.to(device) assert grad_recv.equal(grad_to_check) else: - device = torch.device('cuda:1') + device = torch.device("cuda:1") grad_to_send = grad.to(device) grad_list_to_send = [] for grad_in_list in grad_list: @@ -104,7 +104,7 @@ def check_small_pipeline(): def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") disable_existing_loggers() # check_send_recv_forward() @@ -120,6 +120,6 @@ def test_object_list_p2p(): spawn(check_layer, world_size) -if __name__ == '__main__': +if __name__ == "__main__": disable_existing_loggers() test_object_list_p2p() diff --git a/tests/test_legacy/test_context/configs/parallel_2d_init.py b/tests/test_legacy/test_context/configs/parallel_2d_init.py index 6cf816942fdd..d1203fcdc436 100644 --- a/tests/test_legacy/test_context/configs/parallel_2d_init.py +++ b/tests/test_legacy/test_context/configs/parallel_2d_init.py @@ -1,4 +1,4 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -parallel = dict(pipeline=dict(size=2), tensor=dict(size=4, mode='2d')) +parallel = dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")) diff --git a/tests/test_legacy/test_context/configs/parallel_2p5d_init.py b/tests/test_legacy/test_context/configs/parallel_2p5d_init.py index b946d45b3a91..89e8cd6039f7 100644 --- a/tests/test_legacy/test_context/configs/parallel_2p5d_init.py +++ b/tests/test_legacy/test_context/configs/parallel_2p5d_init.py @@ -1,4 +1,4 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, depth=2, mode='2.5d')) +parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, depth=2, mode="2.5d")) diff --git a/tests/test_legacy/test_context/configs/parallel_3d_init.py b/tests/test_legacy/test_context/configs/parallel_3d_init.py index a1564bbb2d51..f9aa52fa4199 100644 --- a/tests/test_legacy/test_context/configs/parallel_3d_init.py +++ b/tests/test_legacy/test_context/configs/parallel_3d_init.py @@ -1,4 +1,4 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, mode='3d')) +parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, mode="3d")) diff --git a/tests/test_legacy/test_context/test_hybrid_parallel.py b/tests/test_legacy/test_context/test_hybrid_parallel.py index 05cd1d294dcd..b9e44bb34362 100644 --- a/tests/test_legacy/test_context/test_hybrid_parallel.py +++ b/tests/test_legacy/test_context/test_hybrid_parallel.py @@ -3,7 +3,6 @@ from pathlib import Path -import pytest import torch from colossalai.legacy import launch @@ -13,7 +12,7 @@ from colossalai.legacy.global_variables import tensor_parallel_env as tp_env from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn -CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) +CONFIG_PATH_LIST = list(Path(__file__).parent.glob("configs/*.py")) def check_data_parallel_rank(rank): @@ -50,11 +49,11 @@ def check_model_parallel_rank(rank): def check_tensor_parallel_rank(rank): - if tp_env.mode == '2d': + if tp_env.mode == "2d": check_2d_tensor_parallel_rank(rank) - elif tp_env == '2.5d': + elif tp_env == "2.5d": check_2p5d_tensor_parallel_rank(rank) - elif tp_env == '3d': + elif tp_env == "3d": check_3d_tensor_parallel_rank(rank) @@ -115,13 +114,9 @@ def check_3d_tensor_parallel_rank(rank): def init_context(config_path, rank, world_size, backend, port, host): - dist_args = dict(config=config_path, - rank=rank, - world_size=world_size, - backend=backend, - port=port, - host=host, - verbose=True) + dist_args = dict( + config=config_path, rank=rank, world_size=world_size, backend=backend, port=port, host=host, verbose=True + ) launch(**dist_args) check_tensor_parallel_rank(rank) @@ -134,12 +129,9 @@ def init_context(config_path, rank, world_size, backend, port, host): def run_dist(rank, world_size, port, backend, port_list, host): for config_path, current_port in zip(CONFIG_PATH_LIST, port_list): - init_context(config_path=config_path, - rank=rank, - world_size=world_size, - backend=backend, - port=current_port, - host=host) + init_context( + config_path=config_path, rank=rank, world_size=world_size, backend=backend, port=current_port, host=host + ) reset_seeds() @@ -158,8 +150,8 @@ def test_context(): port_list.append(port) break - spawn(run_dist, world_size, backend='gloo', port_list=port_list, host='localhost') + spawn(run_dist, world_size, backend="gloo", port_list=port_list, host="localhost") -if __name__ == '__main__': +if __name__ == "__main__": test_context() diff --git a/tests/test_legacy/test_data/test_cifar10_dataset.py b/tests/test_legacy/test_data/test_cifar10_dataset.py index dfa9fa211ef0..4851f1b85817 100644 --- a/tests/test_legacy/test_data/test_cifar10_dataset.py +++ b/tests/test_legacy/test_data/test_cifar10_dataset.py @@ -4,7 +4,6 @@ import os from pathlib import Path -import pytest from torch.utils.data import DataLoader from torchvision import datasets, transforms @@ -15,7 +14,7 @@ def test_cifar10_dataset(): transform_pipeline = transforms.Compose(transform_pipeline) # build dataset - dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ["DATA"]), train=True, download=True, transform=transform_pipeline) # build dataloader dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2) @@ -23,5 +22,5 @@ def test_cifar10_dataset(): img, label = data_iter.next() -if __name__ == '__main__': +if __name__ == "__main__": test_cifar10_dataset() diff --git a/tests/test_legacy/test_data/test_data_parallel_sampler.py b/tests/test_legacy/test_data/test_data_parallel_sampler.py index cf10fe9dfa3c..1786b4a77a8b 100644 --- a/tests/test_legacy/test_data/test_data_parallel_sampler.py +++ b/tests/test_legacy/test_data/test_data_parallel_sampler.py @@ -4,7 +4,6 @@ import os from pathlib import Path -import pytest import torch import torch.distributed as dist from torchvision import datasets, transforms @@ -16,24 +15,26 @@ from colossalai.legacy.utils import get_dataloader from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = Config(dict( - parallel=dict( - pipeline=dict(size=1), - tensor=dict(size=1, mode=None), - ), - seed=1024, -)) +CONFIG = Config( + dict( + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=1, mode=None), + ), + seed=1024, + ) +) def run_data_sampler(rank, world_size, port): - dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend="gloo", port=port, host="localhost") colossalai.legacy.launch(**dist_args) - print('finished initialization') + print("finished initialization") # build dataset transform_pipeline = [transforms.ToTensor()] transform_pipeline = transforms.Compose(transform_pipeline) - dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ["DATA"]), train=True, download=True, transform=transform_pipeline) # build dataloader dataloader = get_dataloader(dataset, batch_size=8, add_sampler=True) @@ -50,7 +51,8 @@ def run_data_sampler(rank, world_size, port): if gpc.get_local_rank(ParallelMode.DATA) != 0: assert not torch.equal( - img, img_to_compare), 'Same image was distributed across ranks but expected it to be different' + img, img_to_compare + ), "Same image was distributed across ranks but expected it to be different" torch.cuda.empty_cache() @@ -59,5 +61,5 @@ def test_data_sampler(): spawn(run_data_sampler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_data_sampler() diff --git a/tests/test_legacy/test_data/test_deterministic_dataloader.py b/tests/test_legacy/test_data/test_deterministic_dataloader.py index 421b8d255318..abb442f48203 100644 --- a/tests/test_legacy/test_data/test_deterministic_dataloader.py +++ b/tests/test_legacy/test_data/test_deterministic_dataloader.py @@ -4,7 +4,6 @@ import os from pathlib import Path -import pytest import torch import torch.distributed as dist from torchvision import datasets, transforms @@ -20,8 +19,8 @@ dict( train_data=dict( dataset=dict( - type='CIFAR10', - root=Path(os.environ['DATA']), + type="CIFAR10", + root=Path(os.environ["DATA"]), train=True, download=True, ), @@ -32,17 +31,18 @@ tensor=dict(size=1, mode=None), ), seed=1024, - )) + ) +) def run_data_sampler(rank, world_size, port): - dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend='gloo', port=port, host='localhost') + dist_args = dict(config=CONFIG, rank=rank, world_size=world_size, backend="gloo", port=port, host="localhost") colossalai.legacy.launch(**dist_args) # build dataset transform_pipeline = [transforms.ToTensor(), transforms.RandomCrop(size=32, padding=4)] transform_pipeline = transforms.Compose(transform_pipeline) - dataset = datasets.CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_pipeline) + dataset = datasets.CIFAR10(root=Path(os.environ["DATA"]), train=True, download=True, transform=transform_pipeline) # build dataloader dataloader = get_dataloader(dataset, batch_size=8, add_sampler=False) @@ -60,8 +60,9 @@ def run_data_sampler(rank, world_size, port): if gpc.get_local_rank(ParallelMode.DATA) != 0: # this is without sampler # this should be false if data parallel sampler to given to the dataloader - assert torch.equal(img, - img_to_compare), 'Same image was distributed across ranks and expected it to be the same' + assert torch.equal( + img, img_to_compare + ), "Same image was distributed across ranks and expected it to be the same" torch.cuda.empty_cache() @@ -70,5 +71,5 @@ def test_data_sampler(): spawn(run_data_sampler, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_data_sampler() diff --git a/tests/test_legacy/test_engine/test_engine.py b/tests/test_legacy/test_engine/test_engine.py index 8499784038d2..b07fe8abe86e 100644 --- a/tests/test_legacy/test_engine/test_engine.py +++ b/tests/test_legacy/test_engine/test_engine.py @@ -6,25 +6,26 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), - fp16=dict(mode=None), - clip_grad_norm=1.0) +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0 +) -@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'repeated_computed_layers']) -@parameterize('amp_mode', [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) +@parameterize("model_name", ["repeated_computed_layers", "resnet18", "repeated_computed_layers"]) +@parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) def run_train(model_name, amp_mode): # FIXME: test bert get_components_func = non_distributed_component_funcs.get_callable(model_name) - gpc.config.fp16['mode'] = amp_mode + gpc.config.fp16["mode"] = amp_mode model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() model = model_builder(checkpoint=False) - engine, train_dataloader, *args = colossalai.legacy.initialize(model=model, - optimizer=optimizer_class(model.parameters(), - lr=1e-3), - criterion=criterion, - train_dataloader=train_dataloader) + engine, train_dataloader, *args = colossalai.legacy.initialize( + model=model, + optimizer=optimizer_class(model.parameters(), lr=1e-3), + criterion=criterion, + train_dataloader=train_dataloader, + ) try: engine.train() @@ -49,12 +50,9 @@ def run_train(model_name, amp_mode): def run_engine(rank, world_size, port): # init dist env - colossalai.legacy.launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) run_train() @@ -64,5 +62,5 @@ def test_engine(): spawn(run_engine, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_engine() diff --git a/tests/test_legacy/test_engine/test_gradient_accumluation.py b/tests/test_legacy/test_engine/test_gradient_accumluation.py index 168c93c1a572..262876e0ba42 100644 --- a/tests/test_legacy/test_engine/test_gradient_accumluation.py +++ b/tests/test_legacy/test_engine/test_gradient_accumluation.py @@ -19,46 +19,40 @@ BATCH_SIZE = 2 NUM_CLASSES = 10 -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), - clip_grad_norm=1.0, - gradient_accumulation=4) +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), clip_grad_norm=1.0, gradient_accumulation=4 +) def run_no_pipeline(rank, world_size, port): - # init dist env - colossalai.legacy.launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) # build model model = resnet18(num_classes=10) # build dataloaders - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ])) - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=BATCH_SIZE, - pin_memory=True, - drop_last=True) + train_dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))] + ), + ) + train_dataloader = get_dataloader( + dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True + ) # build optimizer optimizer = Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() - engine, train_dataloader, *args = colossalai.legacy.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - logger = get_dist_logger() + engine, train_dataloader, *args = colossalai.legacy.initialize( + model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader + ) + get_dist_logger() rank = torch.distributed.get_rank() param_track = [] grad_track = [] @@ -79,12 +73,13 @@ def run_no_pipeline(rank, world_size, port): param_track.append(next(model.parameters())[0].clone()) grad_track.append(next(model.parameters()).grad[0].clone()) step += 1 - if step == CONFIG['gradient_accumulation']: + if step == CONFIG["gradient_accumulation"]: break - assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations' - assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \ - 'param should be the same in the first few iterations and only changed in the last iteration' + assert not torch.all(grad_track[0] == grad_track[-1]), "grad should be different in different iterations" + assert torch.all(param_track[0] == param_track[1]) and not torch.all( + param_track[0] == param_track[-1] + ), "param should be the same in the first few iterations and only changed in the last iteration" gpc.destroy() torch.cuda.empty_cache() @@ -96,5 +91,5 @@ def test_engine(): spawn(run_no_pipeline, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_engine() diff --git a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py index 859707e6129d..8a9a73d65f38 100644 --- a/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_legacy/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -44,7 +44,7 @@ def check_linear_col(): W = W.clone() W.requires_grad = True - B_shape = (OUTPUT_SIZE) + B_shape = OUTPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) dist.broadcast(B_master, src=0) B = torch.chunk(B_master, DEPTH, dim=0)[i] @@ -65,7 +65,7 @@ def check_linear_col(): C = torch.chunk(C_master, DEPTH, dim=-1)[i] check_equal(out, C) - print_rank_0('linear_col forward: pass') + print_rank_0("linear_col forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -87,7 +87,7 @@ def check_linear_col(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, layer.bias.grad) - print_rank_0('linear_col backward: pass') + print_rank_0("linear_col backward: pass") def check_linear_row(): @@ -114,7 +114,7 @@ def check_linear_row(): W = W.clone() W.requires_grad = True - B_shape = (INPUT_SIZE) + B_shape = INPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) dist.broadcast(B_master, src=0) B = B_master.clone() @@ -134,7 +134,7 @@ def check_linear_row(): C = C_master.clone() check_equal(out, C) - print_rank_0('linear_row forward: pass') + print_rank_0("linear_row forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -155,7 +155,7 @@ def check_linear_row(): B_grad = B_master.grad check_equal(B_grad, layer.bias.grad) - print_rank_0('linear_row backward: pass') + print_rank_0("linear_row backward: pass") def check_embed(): @@ -184,7 +184,7 @@ def check_embed(): C_master = embed_master(A_master) C = C_master.clone() check_equal(out, C) - print_rank_0('embed forward: pass') + print_rank_0("embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -197,7 +197,7 @@ def check_embed(): B_grad = embed_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') + print_rank_0("embed backward: pass") def check_vocab_parallel_embed(): @@ -226,7 +226,7 @@ def check_vocab_parallel_embed(): C_master = embed_master(A_master) C = C_master.clone() check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') + print_rank_0("vocab parallel embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -239,7 +239,7 @@ def check_vocab_parallel_embed(): B_grad = embed_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') + print_rank_0("vocab parallel embed backward: pass") def check_classifier_no_given_weight(): @@ -283,7 +283,7 @@ def check_classifier_no_given_weight(): C = C_master.clone() check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') + print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -305,7 +305,7 @@ def check_classifier_no_given_weight(): B_grad = layer_master.bias.grad check_equal(B_grad, layer.bias.grad) - print_rank_0('classifier (no given weight) backward: pass') + print_rank_0("classifier (no given weight) backward: pass") def check_vocab_parallel_classifier_no_given_weight(): @@ -343,7 +343,7 @@ def check_vocab_parallel_classifier_no_given_weight(): C = torch.chunk(C_master, DEPTH, dim=-1)[i] check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') + print_rank_0("vocab parallel classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -365,7 +365,7 @@ def check_vocab_parallel_classifier_no_given_weight(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, layer.bias.grad) - print_rank_0('vocab parallel classifier (no given weight) backward: pass') + print_rank_0("vocab parallel classifier (no given weight) backward: pass") def check_classifier_given_embed_weight(): @@ -401,7 +401,7 @@ def check_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = C_master.clone() check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') + print_rank_0("classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -416,7 +416,7 @@ def check_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('classifier (given embed weight) backward: pass') + print_rank_0("classifier (given embed weight) backward: pass") def check_vocab_parallel_classifier_given_embed_weight(): @@ -452,7 +452,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=-1)[i] check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + print_rank_0("vocab parallel classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -468,7 +468,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + print_rank_0("vocab parallel classifier (given embed weight) backward: pass") def check_vocab_parallel_loss(): @@ -495,7 +495,7 @@ def check_vocab_parallel_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('vocab parallel loss forward: pass') + print_rank_0("vocab parallel loss forward: pass") loss.backward() loss_master.backward() @@ -503,7 +503,7 @@ def check_vocab_parallel_loss(): out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i] check_equal(out_grad, out.grad) - print_rank_0('vocab parallel loss backward: pass') + print_rank_0("vocab parallel loss backward: pass") @torch.no_grad() @@ -531,7 +531,7 @@ def check_linear_row_stream_inference(): W = torch.chunk(W_master, DEPTH, dim=-1)[i] W = W.clone() - B_shape = (INPUT_SIZE) + B_shape = INPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) dist.broadcast(B_master, src=0) B = B_master.clone() @@ -550,4 +550,4 @@ def check_linear_row_stream_inference(): C = C_master.clone() check_equal(out, C) - print_rank_0('linear_row forward: pass') + print_rank_0("linear_row forward: pass") diff --git a/tests/test_legacy/test_layers/test_1d/test_1d.py b/tests/test_legacy/test_layers/test_1d/test_1d.py index 2a016ed7b33d..cebbedd303ee 100644 --- a/tests/test_legacy/test_layers/test_1d/test_1d.py +++ b/tests/test_legacy/test_layers/test_1d/test_1d.py @@ -10,12 +10,14 @@ from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode="1d")), +) def check_layer(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_linear_col() check_linear_row() @@ -39,5 +41,5 @@ def test_1d(): spawn(check_layer, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_1d() diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py index 494497be33e2..0bbc72eca809 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -48,7 +48,7 @@ def check_linear(): W = W.clone() W.requires_grad = True - B_shape = (OUTPUT_SIZE) + B_shape = OUTPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, DEPTH, dim=-1)[j] @@ -71,7 +71,7 @@ def check_linear(): C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('linear forward: pass') + print_rank_0("linear forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -99,7 +99,7 @@ def check_linear(): # if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('linear backward: pass') + print_rank_0("linear backward: pass") def check_layernorm(): @@ -136,7 +136,7 @@ def check_layernorm(): C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('layer norm forward: pass') + print_rank_0("layer norm forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -150,7 +150,7 @@ def check_layernorm(): A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j] check_equal(A_grad, A.grad) - print_rank_0('layer norm backward: pass') + print_rank_0("layer norm backward: pass") def check_embed(): @@ -181,7 +181,7 @@ def check_embed(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('embed forward: pass') + print_rank_0("embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -197,7 +197,7 @@ def check_embed(): B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') + print_rank_0("embed backward: pass") def check_patch_embed(): @@ -238,7 +238,7 @@ def check_patch_embed(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('patch embed forward: pass') + print_rank_0("patch embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -270,7 +270,7 @@ def check_patch_embed(): bias_grad = torch.chunk(bias_grad, DEPTH)[j] bias_grad = torch.chunk(bias_grad, DEPTH)[i] check_equal(bias_grad, layer.bias.grad) - print_rank_0('patch embed backward: pass') + print_rank_0("patch embed backward: pass") def check_vocab_parallel_embed(): @@ -301,7 +301,7 @@ def check_vocab_parallel_embed(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') + print_rank_0("vocab parallel embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -317,7 +317,7 @@ def check_vocab_parallel_embed(): B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') + print_rank_0("vocab parallel embed backward: pass") def check_classifier_no_given_weight(): @@ -368,7 +368,7 @@ def check_classifier_no_given_weight(): # C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') + print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -395,7 +395,7 @@ def check_classifier_no_given_weight(): # if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('classifier (no given weight) backward: pass') + print_rank_0("classifier (no given weight) backward: pass") def check_vocab_parallel_classifier_no_given_weight(): @@ -437,7 +437,7 @@ def check_vocab_parallel_classifier_no_given_weight(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') + print_rank_0("vocab parallel classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -463,7 +463,7 @@ def check_vocab_parallel_classifier_no_given_weight(): B_grad = torch.chunk(B_grad, DEPTH)[j] B_grad = torch.chunk(B_grad, DEPTH)[i] check_equal(B_grad, layer.bias.grad) - print_rank_0('vocab parallel classifier (no given weight) backward: pass') + print_rank_0("vocab parallel classifier (no given weight) backward: pass") def check_classifier_given_embed_weight(): @@ -499,7 +499,7 @@ def check_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=0)[i] check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') + print_rank_0("classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -515,7 +515,7 @@ def check_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('classifier (given embed weight) backward: pass') + print_rank_0("classifier (given embed weight) backward: pass") def check_vocab_parallel_classifier_given_embed_weight(): @@ -552,7 +552,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + print_rank_0("vocab parallel classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -569,14 +569,14 @@ def check_vocab_parallel_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j] W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + print_rank_0("vocab parallel classifier (given embed weight) backward: pass") def check_loss(): device = get_current_device() dtype = torch.float32 - j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) + gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW) i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL) criterion = CrossEntropyLoss2D() @@ -596,7 +596,7 @@ def check_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('cross entropy loss forward: pass') + print_rank_0("cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -604,7 +604,7 @@ def check_loss(): out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] check_equal(out_grad, out.grad) - print_rank_0('cross entropy loss backward: pass') + print_rank_0("cross entropy loss backward: pass") def check_vocab_parallel_loss(): @@ -632,7 +632,7 @@ def check_vocab_parallel_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('vocab parallel cross entropy loss forward: pass') + print_rank_0("vocab parallel cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -641,7 +641,7 @@ def check_vocab_parallel_loss(): out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[j] check_equal(out_grad, out.grad) - print_rank_0('vocab parallel cross entropy loss backward: pass') + print_rank_0("vocab parallel cross entropy loss backward: pass") # def check_attention(): diff --git a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py index 034dbe5ca29c..9c126cefeba8 100644 --- a/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py +++ b/tests/test_legacy/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -14,10 +14,12 @@ def check_AB(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float @@ -42,10 +44,22 @@ def check_AB(): out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH) - out = Matmul_AB_2D.apply(A, B, DEPTH, out_shape, i, j, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, - data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + out = Matmul_AB_2D.apply( + A, + B, + DEPTH, + out_shape, + i, + j, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) A_master = A_master.clone() A_master.requires_grad = True B_master = B_master.clone() @@ -55,7 +69,7 @@ def check_AB(): C = torch.chunk(C, DEPTH, dim=-1)[j] # check forward correctness check_equal(out, C) - print_rank_0('AB forward: pass') + print_rank_0("AB forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -77,15 +91,17 @@ def check_AB(): B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] # check backward correctness check_equal(B_grad, B.grad) - print_rank_0('AB backward: pass') + print_rank_0("AB backward: pass") def check_ABT(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float @@ -110,11 +126,22 @@ def check_ABT(): B = B.clone() B.requires_grad = True - out = Matmul_ABT_2D.apply(C, B, DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), i, j, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank, - pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + out = Matmul_ABT_2D.apply( + C, + B, + DEPTH, + (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), + i, + j, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) C_master = C_master.clone() C_master.requires_grad = True B_master = B_master.clone() @@ -123,7 +150,7 @@ def check_ABT(): A = torch.chunk(A_master, DEPTH, dim=0)[i] A = torch.chunk(A, DEPTH, dim=-1)[j] check_equal(out, A) - print_rank_0('ABT forward: pass') + print_rank_0("ABT forward: pass") grad_shape = A_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -144,15 +171,17 @@ def check_ABT(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] check_equal(B_grad, B.grad) - print_rank_0('ABT backward: pass') + print_rank_0("ABT backward: pass") def check_ATB(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) device = get_current_device() @@ -177,21 +206,33 @@ def check_ATB(): C = C.clone() C.requires_grad = True - out = Matmul_ATB_2D.apply(A, C, DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), i, j, - ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank, - pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + out = Matmul_ATB_2D.apply( + A, + C, + DEPTH, + (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), + i, + j, + ParallelMode.PARALLEL_2D_ROW, + ParallelMode.PARALLEL_2D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (HIDDEN_SIZE, 4 * HIDDEN_SIZE) A_master = A_master.clone() A_master.requires_grad = True C_master = C_master.clone() C_master.requires_grad = True B_master = torch.matmul( - A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) + A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]) + ) B = torch.chunk(B_master, DEPTH, dim=0)[i] B = torch.chunk(B, DEPTH, dim=-1)[j] check_equal(out, B) - print_rank_0('ATB forward: pass') + print_rank_0("ATB forward: pass") grad_shape = B_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -211,4 +252,4 @@ def check_ATB(): C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i] C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j] check_equal(C_grad, C.grad) - print_rank_0('ATB backward: pass') + print_rank_0("ATB backward: pass") diff --git a/tests/test_legacy/test_layers/test_2d/test_2d.py b/tests/test_legacy/test_layers/test_2d/test_2d.py index a4b46793f19d..77a4b281a746 100644 --- a/tests/test_legacy/test_layers/test_2d/test_2d.py +++ b/tests/test_legacy/test_layers/test_2d/test_2d.py @@ -23,7 +23,9 @@ from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),) +CONFIG = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode="2d")), +) def check_operations(): @@ -48,7 +50,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False @@ -65,5 +67,5 @@ def test_2d(): spawn(check_layer_and_operation, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_2d() diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py index e7a9a8be45d0..283e7f68374f 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -30,7 +30,7 @@ def check_linear(): i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False) @@ -50,7 +50,7 @@ def check_linear(): W = W.clone() W.requires_grad = True - B_shape = (OUTPUT_SIZE) + B_shape = OUTPUT_SIZE B_master = torch.randn(B_shape, dtype=dtype, device=device) torch.distributed.broadcast(B_master, src=0) B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] @@ -60,7 +60,7 @@ def check_linear(): layer.weight = Parameter(W) layer.bias = Parameter(B) out = layer(A) - bias = layer.bias + layer.bias A_master = A_master.clone() A_master.requires_grad = True @@ -73,7 +73,7 @@ def check_linear(): C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('linear forward: pass') + print_rank_0("linear forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -100,7 +100,7 @@ def check_linear(): if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('linear backward: pass') + print_rank_0("linear backward: pass") def check_layernorm(): @@ -111,7 +111,7 @@ def check_layernorm(): i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype) @@ -138,7 +138,7 @@ def check_layernorm(): C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('layer norm forward: pass') + print_rank_0("layer norm forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -152,7 +152,7 @@ def check_layernorm(): A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i] A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j] check_equal(A_grad, A.grad) - print_rank_0('layer norm backward: pass') + print_rank_0("layer norm backward: pass") def check_embed(): @@ -160,7 +160,7 @@ def check_embed(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -184,7 +184,7 @@ def check_embed(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('embed forward: pass') + print_rank_0("embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -200,7 +200,7 @@ def check_embed(): B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') + print_rank_0("embed backward: pass") def check_patch_embed(): @@ -208,7 +208,7 @@ def check_patch_embed(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype) torch.nn.init.ones_(layer.cls_token) @@ -242,7 +242,7 @@ def check_patch_embed(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('patch embed forward: pass') + print_rank_0("patch embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -274,7 +274,7 @@ def check_patch_embed(): bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j] bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i] check_equal(bias_grad, layer.bias.grad) - print_rank_0('patch embed backward: pass') + print_rank_0("patch embed backward: pass") def check_vocab_parallel_embed(): @@ -282,7 +282,7 @@ def check_vocab_parallel_embed(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -306,7 +306,7 @@ def check_vocab_parallel_embed(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') + print_rank_0("vocab parallel embed forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -322,7 +322,7 @@ def check_vocab_parallel_embed(): B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') + print_rank_0("vocab parallel embed backward: pass") def check_classifier_no_given_weight(): @@ -374,7 +374,7 @@ def check_classifier_no_given_weight(): # C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') + print_rank_0("classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -401,7 +401,7 @@ def check_classifier_no_given_weight(): # if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('classifier (no given weight) backward: pass') + print_rank_0("classifier (no given weight) backward: pass") def check_vocab_parallel_classifier_no_given_weight(): @@ -409,7 +409,7 @@ def check_vocab_parallel_classifier_no_given_weight(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) layer = layer.to(dtype).to(device) @@ -442,7 +442,7 @@ def check_vocab_parallel_classifier_no_given_weight(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') + print_rank_0("vocab parallel classifier (no given weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -468,7 +468,7 @@ def check_vocab_parallel_classifier_no_given_weight(): B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j] if i == 0: check_equal(B_grad, layer.bias.grad) - print_rank_0('vocab parallel classifier (no given weight) backward: pass') + print_rank_0("vocab parallel classifier (no given weight) backward: pass") def check_classifier_given_embed_weight(): @@ -476,7 +476,7 @@ def check_classifier_given_embed_weight(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -504,7 +504,7 @@ def check_classifier_given_embed_weight(): C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') + print_rank_0("classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -520,7 +520,7 @@ def check_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('classifier (given embed weight) backward: pass') + print_rank_0("classifier (given embed weight) backward: pass") def check_vocab_parallel_classifier_given_embed_weight(): @@ -528,7 +528,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE) embed = embed.to(dtype).to(device) @@ -557,7 +557,7 @@ def check_vocab_parallel_classifier_given_embed_weight(): C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i] C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + print_rank_0("vocab parallel classifier (given embed weight) forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -574,15 +574,15 @@ def check_vocab_parallel_classifier_given_embed_weight(): W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j] W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i] check_equal(W_grad, embed.weight.grad) - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + print_rank_0("vocab parallel classifier (given embed weight) backward: pass") def check_loss(): device = get_current_device() dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) - j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) criterion = CrossEntropyLoss2p5D() criterion_master = torch.nn.CrossEntropyLoss() @@ -601,7 +601,7 @@ def check_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('cross entropy loss forward: pass') + print_rank_0("cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -609,7 +609,7 @@ def check_loss(): out_grad = out_master.grad out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] check_equal(out_grad, out.grad) - print_rank_0('cross entropy loss backward: pass') + print_rank_0("cross entropy loss backward: pass") def check_vocab_parallel_loss(): @@ -617,7 +617,7 @@ def check_vocab_parallel_loss(): dtype = torch.float32 i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) - k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) + gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) criterion = VocabParallelCrossEntropyLoss2p5D() criterion_master = torch.nn.CrossEntropyLoss() @@ -637,7 +637,7 @@ def check_vocab_parallel_loss(): out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) check_equal(loss, loss_master) - print_rank_0('vocab parallel cross entropy loss forward: pass') + print_rank_0("vocab parallel cross entropy loss forward: pass") loss.backward() loss_master.backward() @@ -646,7 +646,7 @@ def check_vocab_parallel_loss(): out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i] out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j] check_equal(out_grad, out.grad) - print_rank_0('vocab parallel cross entropy loss backward: pass') + print_rank_0("vocab parallel cross entropy loss backward: pass") # def check_attention(): diff --git a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py index fe78ef669bf0..992bd6107f08 100644 --- a/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -11,10 +11,12 @@ def check_AB(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float @@ -39,11 +41,23 @@ def check_AB(): B.requires_grad = True out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM) - out = Matmul_AB_2p5D.apply(A, B, TESSERACT_DIM, out_shape, i, j, k, ParallelMode.PARALLEL_2P5D_ROW, - ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, pipeline_parallel_rank, - pipeline_parallel_size, tensor_parallel_size) - - C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) + out = Matmul_AB_2p5D.apply( + A, + B, + TESSERACT_DIM, + out_shape, + i, + j, + k, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) A_master = A_master.clone() A_master.requires_grad = True B_master = B_master.clone() @@ -53,7 +67,7 @@ def check_AB(): C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j] # check forward correctness check_equal(out, C) - print_rank_0('AB forward: pass') + print_rank_0("AB forward: pass") grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -75,15 +89,17 @@ def check_AB(): B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] # check backward correctness check_equal(B_grad, B.grad) - print_rank_0('AB backward: pass') + print_rank_0("AB backward: pass") def check_ABT(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) dtype = torch.float @@ -109,12 +125,23 @@ def check_ABT(): B = B.clone() B.requires_grad = True - out = Matmul_ABT_2p5D.apply(C, B, TESSERACT_DIM, - (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), i, j, k, - ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, - pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + out = Matmul_ABT_2p5D.apply( + C, + B, + TESSERACT_DIM, + (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), + i, + j, + k, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) C_master = C_master.clone() C_master.requires_grad = True B_master = B_master.clone() @@ -123,7 +150,7 @@ def check_ABT(): A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i] A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j] check_equal(out, A) - print_rank_0('ABT forward: pass') + print_rank_0("ABT forward: pass") grad_shape = A_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -144,15 +171,17 @@ def check_ABT(): B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i] B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j] check_equal(B_grad, B.grad) - print_rank_0('ABT backward: pass') + print_rank_0("ABT backward: pass") def check_ATB(): data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA) - pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank( - ParallelMode.PIPELINE) - pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size( - ParallelMode.PIPELINE) + pipeline_parallel_rank = ( + 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + ) + pipeline_parallel_size = ( + 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE) + ) tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR) device = get_current_device() @@ -178,22 +207,34 @@ def check_ATB(): C = C.clone() C.requires_grad = True - out = Matmul_ATB_2p5D.apply(A, C, TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM), - i, j, k, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, - data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, - tensor_parallel_size) - - B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) + out = Matmul_ATB_2p5D.apply( + A, + C, + TESSERACT_DIM, + (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM), + i, + j, + k, + ParallelMode.PARALLEL_2P5D_ROW, + ParallelMode.PARALLEL_2P5D_COL, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + ) + + (HIDDEN_SIZE, 4 * HIDDEN_SIZE) A_master = A_master.clone() A_master.requires_grad = True C_master = C_master.clone() C_master.requires_grad = True B_master = torch.matmul( - A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) + A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]) + ) B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] check_equal(out, B) - print_rank_0('ATB forward: pass') + print_rank_0("ATB forward: pass") grad_shape = B_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=device) @@ -213,4 +254,4 @@ def check_ATB(): C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i] C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j] check_equal(C_grad, C.grad) - print_rank_0('ATB backward: pass') + print_rank_0("ATB backward: pass") diff --git a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py index 38ba3ba78575..437a8f8a7265 100644 --- a/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py +++ b/tests/test_legacy/test_layers/test_2p5d/test_2p5d.py @@ -8,10 +8,12 @@ from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict( - pipeline=dict(size=1), - tensor=dict(size=4, mode='2.5d', depth=1), -),) +CONFIG = dict( + parallel=dict( + pipeline=dict(size=1), + tensor=dict(size=4, mode="2.5d", depth=1), + ), +) def check_operations(): @@ -36,7 +38,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False @@ -53,5 +55,5 @@ def test_2p5d(): spawn(check_layer_and_operation, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_2p5d() diff --git a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py index 2a9dcc3cdc16..a4a4ae9a5ba4 100644 --- a/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_legacy/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -73,14 +73,15 @@ def check_linear(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'linear forward: {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "linear forward: {0} --> {1} | {2:.3f} s".format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger + ) A_master = A_master.clone() A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] C = torch.chunk(C, DEPTH, dim=0)[k] - logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=get_current_device()) @@ -93,24 +94,24 @@ def check_linear(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('linear backward: {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("linear backward: {:.3f} s".format(bwd_end - bwd_start), logger) C_master.backward(grad_master) A_grad = A_master.grad A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} linear backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + logger.info("Rank {} linear backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad))) B_grad = layer_master.weight.grad.transpose(0, 1) B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + logger.info("Rank {} linear backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad))) bias_grad = layer_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[j] - logger.info('Rank {} linear backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad))) + logger.info("Rank {} linear backward (bias_grad): {}".format(rank, check_equal(bias_grad, layer.bias.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -157,8 +158,11 @@ def check_layernorm(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), - fwd_end - fwd_start), logger) + "layer norm forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() A_master.requires_grad = True @@ -166,7 +170,7 @@ def check_layernorm(): C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} layernorm forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -179,22 +183,22 @@ def check_layernorm(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("layer norm backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) C_master.backward(grad_master) A_grad = A_master.grad A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} layernorm backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad))) + logger.info("Rank {} layernorm backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad))) bias_grad = norm_master.weight.grad bias_grad = torch.chunk(bias_grad, DEPTH)[k] - logger.info('Rank {} layernorm backward (weight_grad): {}'.format(rank, check_equal(bias_grad, norm.weight.grad))) + logger.info("Rank {} layernorm backward (weight_grad): {}".format(rank, check_equal(bias_grad, norm.weight.grad))) bias_grad = norm_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[k] - logger.info('Rank {} layernorm backward (bias_grad): {}'.format(rank, check_equal(bias_grad, norm.bias.grad))) + logger.info("Rank {} layernorm backward (bias_grad): {}".format(rank, check_equal(bias_grad, norm.bias.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -241,14 +245,17 @@ def check_classifier_no_given_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=get_current_device()) @@ -261,7 +268,7 @@ def check_classifier_no_given_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("classifier (no given weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -269,21 +276,29 @@ def check_classifier_no_given_weight(): A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format( - rank, check_equal(A_grad, A.grad))) + logger.info( + "Rank {} classifier (no given weight) backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad)) + ) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] if j == k: - logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( - rank, check_equal(B_grad, layer.weight.grad))) + logger.info( + "Rank {} classifier (no given weight) backward (weight_grad): {}".format( + rank, check_equal(B_grad, layer.weight.grad) + ) + ) else: - logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format( - rank, layer.weight.grad is None)) + logger.info( + "Rank {} classifier (no given weight) backward (weight_grad): {}".format(rank, layer.weight.grad is None) + ) bias_grad = layer_master.bias.grad - logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format( - rank, check_equal(bias_grad, layer.bias.grad))) + logger.info( + "Rank {} classifier (no given weight) backward (bias_grad): {}".format( + rank, check_equal(bias_grad, layer.bias.grad) + ) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -333,15 +348,18 @@ def check_vocab_parallel_classifier_no_given_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() A_master.requires_grad = True C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] C = torch.chunk(C, DEPTH, dim=0)[k] - logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} vocab parallel classifier (no given weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -355,8 +373,9 @@ def check_vocab_parallel_classifier_no_given_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), - logger) + print_rank_0( + "vocab parallel classifier (no given weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger + ) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -364,20 +383,29 @@ def check_vocab_parallel_classifier_no_given_weight(): A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i] A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k] A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j] - logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format( - rank, check_equal(A_grad, A.grad))) + logger.info( + "Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}".format( + rank, check_equal(A_grad, A.grad) + ) + ) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format( - rank, check_equal(B_grad, layer.weight.grad))) + logger.info( + "Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}".format( + rank, check_equal(B_grad, layer.weight.grad) + ) + ) bias_grad = layer_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[j] - logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format( - rank, check_equal(bias_grad, layer.bias.grad))) + logger.info( + "Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}".format( + rank, check_equal(bias_grad, layer.bias.grad) + ) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -423,13 +451,16 @@ def check_classifier_given_embed_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) @@ -442,7 +473,7 @@ def check_classifier_given_embed_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("classifier (given embed weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -450,11 +481,15 @@ def check_classifier_given_embed_weight(): B_grad = embed_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] if j == k: - logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( - rank, check_equal(B_grad, embed.weight.grad))) + logger.info( + "Rank {} classifier (given embed weight) backward (weight_grad): {}".format( + rank, check_equal(B_grad, embed.weight.grad) + ) + ) else: - logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format( - rank, embed.weight.grad is None)) + logger.info( + "Rank {} classifier (given embed weight) backward (weight_grad): {}".format(rank, embed.weight.grad is None) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -501,14 +536,17 @@ def check_vocab_parallel_classifier_given_embed_weight(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger) + "vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() C_master = layer_master(embed_master(A_master)) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[j] C = torch.chunk(C, DEPTH, dim=0)[k] - logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} vocab parallel classifier (given embed weight) forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -522,8 +560,9 @@ def check_vocab_parallel_classifier_given_embed_weight(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), - logger) + print_rank_0( + "vocab parallel classifier (given embed weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger + ) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -532,9 +571,9 @@ def check_vocab_parallel_classifier_given_embed_weight(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, - check_equal(B_grad, - embed.weight.grad))) + logger.info( + "Rank {} vocab parallel embed backward (weight_grad): {}".format(rank, check_equal(B_grad, embed.weight.grad)) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -543,7 +582,7 @@ def check_patch_embed(): rank = torch.distributed.get_rank() device = get_current_device() logger = get_dist_logger() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -582,15 +621,18 @@ def check_patch_embed(): torch.cuda.synchronize() fwd_end = time.time() print_rank_0( - 'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), - fwd_end - fwd_start), logger) + "patch embed forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + logger, + ) A_master = A_master.clone() C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} patch embed forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -604,29 +646,32 @@ def check_patch_embed(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger) + print_rank_0("patch embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger) grad_master = grad_master.clone() C_master.backward(grad_master) cls_grad_master = layer_master.cls_token.grad cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k] - logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad))) + logger.info("Rank {} patch embed backward (cls_grad): {}".format(rank, check_equal(cls_grad, layer.cls_token.grad))) pos_grad_master = layer_master.pos_embed.grad pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k] - logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank, - check_equal(pos_grad, layer.pos_embed.grad))) + logger.info( + "Rank {} patch embed backward (pos_embed_grad): {}".format(rank, check_equal(pos_grad, layer.pos_embed.grad)) + ) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k] - logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank, - check_equal(B_grad, layer.weight.grad))) + logger.info( + "Rank {} patch embed backward (proj_weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad)) + ) bias_grad = layer_master.bias.grad bias_grad = torch.chunk(bias_grad, DEPTH)[k] - logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank, - check_equal(bias_grad, layer.bias.grad))) + logger.info( + "Rank {} patch embed backward (proj_bias_grad): {}".format(rank, check_equal(bias_grad, layer.bias.grad)) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -635,7 +680,7 @@ def check_embed(): rank = torch.distributed.get_rank() device = get_current_device() logger = get_dist_logger() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -664,16 +709,17 @@ def check_embed(): out = layer(A) torch.cuda.synchronize() fwd_end = time.time() - logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), - fwd_end - fwd_start), - ranks=[0]) + logger.info( + "embed forward: pass | {0} --> {1} | {2:.3f} s".format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), + ranks=[0], + ) A_master = A_master.clone() C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} embed forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -686,14 +732,14 @@ def check_embed(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) grad_master = grad_master.clone() C_master.backward(grad_master) B_grad = layer_master.weight.grad B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad))) + logger.info("Rank {} embed backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -702,7 +748,7 @@ def check_vocab_parallel_embed(): rank = torch.distributed.get_rank() device = get_current_device() logger = get_dist_logger() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -733,16 +779,19 @@ def check_vocab_parallel_embed(): out = layer(A) torch.cuda.synchronize() fwd_end = time.time() - logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), - ranks=[0]) + logger.info( + "vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(A.shape), tuple(out.shape), fwd_end - fwd_start + ), + ranks=[0], + ) A_master = A_master.clone() C_master = layer_master(A_master) C = torch.chunk(C_master, DEPTH, dim=0)[i] C = torch.chunk(C, DEPTH, dim=-1)[k] C = torch.chunk(C, DEPTH, dim=0)[j] - logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C))) + logger.info("Rank {} vocab parallel embed forward: {}".format(rank, check_equal(out, C))) grad_shape = C_master.shape grad_master = torch.randn(grad_shape, device=device) @@ -755,7 +804,7 @@ def check_vocab_parallel_embed(): out.backward(grad) torch.cuda.synchronize() bwd_end = time.time() - logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("vocab parallel embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) grad_master = grad_master.clone() C_master.backward(grad_master) @@ -764,9 +813,9 @@ def check_vocab_parallel_embed(): B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j] B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k] - logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank, - check_equal(B_grad, - layer.weight.grad))) + logger.info( + "Rank {} vocab parallel embed backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad)) + ) return fwd_end - fwd_start, bwd_end - bwd_start @@ -798,25 +847,28 @@ def check_loss(): fwd_start = time.time() loss = criterion(out, target_master) fwd_end = time.time() - logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), - fwd_end - fwd_start), - ranks=[0]) + logger.info( + "cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start + ), + ranks=[0], + ) out_master = out_master.clone() out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) - logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + logger.info("Rank {} cross entropy loss forward: {}".format(rank, check_equal(loss, loss_master))) bwd_start = time.time() loss.backward() bwd_end = time.time() - logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("cross entropy loss backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) loss_master.backward() out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] - logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + logger.info("Rank {} cross entropy loss backward: {}".format(rank, check_equal(out_grad, out.grad))) return fwd_end - fwd_start, bwd_end - bwd_start @@ -825,7 +877,7 @@ def check_vocab_parallel_loss(): rank = torch.distributed.get_rank() logger = get_dist_logger() device = get_current_device() - dtype = torch.float32 + torch.float32 input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) @@ -852,25 +904,28 @@ def check_vocab_parallel_loss(): fwd_start = time.time() loss = criterion(out, target_master) fwd_end = time.time() - logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format( - tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start), - ranks=[0]) + logger.info( + "vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s".format( + tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start + ), + ranks=[0], + ) out_master = out_master.clone() out_master.requires_grad = True loss_master = criterion_master(out_master, target_master) - logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master))) + logger.info("Rank {} vocab parallel cross entropy loss forward: {}".format(rank, check_equal(loss, loss_master))) bwd_start = time.time() loss.backward() bwd_end = time.time() - logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0]) + logger.info("vocab parallel cross entropy loss backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0]) loss_master.backward() out_grad = out_master.grad out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i] out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k] out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j] - logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad))) + logger.info("Rank {} vocab parallel cross entropy loss backward: {}".format(rank, check_equal(out_grad, out.grad))) return fwd_end - fwd_start, bwd_end - bwd_start diff --git a/tests/test_legacy/test_layers/test_3d/test_3d.py b/tests/test_legacy/test_layers/test_3d/test_3d.py index 2a32d8935c00..7057e2308b39 100644 --- a/tests/test_legacy/test_layers/test_3d/test_3d.py +++ b/tests/test_legacy/test_layers/test_3d/test_3d.py @@ -23,7 +23,7 @@ CONFIG = dict( parallel=dict( pipeline=1, - tensor=dict(mode='3d', size=8), + tensor=dict(mode="3d", size=8), ), seed=42, ) @@ -44,7 +44,7 @@ def check_layer(): def check_layer_and_operation(rank, world_size, port): disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.deterministic = True @@ -60,5 +60,5 @@ def test_3d(): spawn(check_layer_and_operation, 8) -if __name__ == '__main__': +if __name__ == "__main__": test_3d() diff --git a/tests/test_legacy/test_layers/test_cache_embedding.py b/tests/test_legacy/test_layers/test_cache_embedding.py index c58445a396ec..d64ff56b8a65 100644 --- a/tests/test_legacy/test_layers/test_cache_embedding.py +++ b/tests/test_legacy/test_layers/test_cache_embedding.py @@ -38,10 +38,19 @@ def synthesize_1d_sparse_feature( ): indices_in_batch = batch_size * 2 indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long) - offsets = torch.from_numpy( - np.array([ - 0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch - ])).to(device).long() + offsets = ( + torch.from_numpy( + np.array( + [ + 0, + *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), + indices_in_batch, + ] + ) + ) + .to(device) + .long() + ) return indices, offsets @@ -89,7 +98,7 @@ def test_reorder_with_freq(): chunkid.append(idx // chunk_size) offset_in_chunk.append(idx % chunk_size) - dev = torch.device('cuda') + dev = torch.device("cuda") chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev) offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev) @@ -99,31 +108,31 @@ def test_reorder_with_freq(): mgr.reorder(idx_map) indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev)) - mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor') + mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode="floor") mgr_offsets = torch.remainder(indices, chunk_size) assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}" - assert torch.allclose(offset_in_chunk, mgr_offsets), \ - f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" + assert torch.allclose(offset_in_chunk, mgr_offsets), f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" @clear_cache_before_run() -@parameterize('use_LFU', [True, False]) +@parameterize("use_LFU", [True, False]) def test_freq_aware_embed(use_LFU: bool): - device = torch.device('cuda', 0) + device = torch.device("cuda", 0) evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET - model = CachedEmbeddingBag(NUM_EMBED, - EMBED_DIM, - mode='mean', - include_last_offset=True, - cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0), - ids_freq_mapping=None, - evict_strategy=evict_strategy).to(device) + model = CachedEmbeddingBag( + NUM_EMBED, + EMBED_DIM, + mode="mean", + include_last_offset=True, + cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0), + ids_freq_mapping=None, + evict_strategy=evict_strategy, + ).to(device) assert model.weight.shape[0] == NUM_EMBED - ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device), - mode='mean', - include_last_offset=True, - freeze=False) + ref_model = torch.nn.EmbeddingBag.from_pretrained( + model.weight.detach().to(device), mode="mean", include_last_offset=True, freeze=False + ) assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device)) @@ -149,22 +158,25 @@ def test_freq_aware_embed(use_LFU: bool): model.cache_weight_mgr.flush() model_weight = model.weight.detach().to(device) ref_weight = ref_model.weight.detach() - assert torch.allclose(model_weight, ref_weight), \ - f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" + assert torch.allclose( + model_weight, ref_weight + ), f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" @clear_cache_before_run() -@parameterize('init_freq', [True, False]) +@parameterize("init_freq", [True, False]) def test_lfu_strategy(init_freq: bool): # minimal test to check behavior - Bag = CachedEmbeddingBag(5, - 5, - cache_ratio=3 / 5, - buffer_size=0, - pin_weight=True, - ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None, - warmup_ratio=1.0, - evict_strategy=EvictionStrategy.LFU) + Bag = CachedEmbeddingBag( + 5, + 5, + cache_ratio=3 / 5, + buffer_size=0, + pin_weight=True, + ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None, + warmup_ratio=1.0, + evict_strategy=EvictionStrategy.LFU, + ) # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map) offsets = torch.tensor([0], device="cuda:0") @@ -189,14 +201,15 @@ def test_lfu_strategy(init_freq: bool): # check strategy Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets) - Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1 - Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit - Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3 - Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit - Bag.forward(torch.tensor([0], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1 + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3 + Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit + Bag.forward(torch.tensor([0], device="cuda:0"), offsets) # hit - assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \ - "LFU strategy behavior failed" + assert torch.allclose( + torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1]) + ), "LFU strategy behavior failed" def gather_tensor(tensor, rank, world_size): @@ -211,7 +224,7 @@ def gather_tensor(tensor, rank, world_size): def run_parallel_freq_aware_embed_tablewise(rank, world_size): if world_size != 2: return - device = torch.device('cuda', torch.cuda.current_device()) + device = torch.device("cuda", torch.cuda.current_device()) # initialize weight # 3 feature tables. idx: 0~5, 6~10, 11~17 @@ -221,20 +234,20 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): weight_table3 = weight_tables[11:18] embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [] embedding_bag_config_list.append( - TablewiseEmbeddingBagConfig(num_embeddings=6, - cuda_row_num=4, - assigned_rank=0, - initial_weight=weight_table1.clone().detach().cpu())) + TablewiseEmbeddingBagConfig( + num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu() + ) + ) embedding_bag_config_list.append( - TablewiseEmbeddingBagConfig(num_embeddings=5, - cuda_row_num=4, - assigned_rank=0, - initial_weight=weight_table2.clone().detach().cpu())) + TablewiseEmbeddingBagConfig( + num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu() + ) + ) embedding_bag_config_list.append( - TablewiseEmbeddingBagConfig(num_embeddings=7, - cuda_row_num=4, - assigned_rank=1, - initial_weight=weight_table3.clone().detach().cpu())) + TablewiseEmbeddingBagConfig( + num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu() + ) + ) if rank == 0: _weight = torch.cat([weight_table1, weight_table2], 0) else: @@ -249,7 +262,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): evict_strategy=EvictionStrategy.LFU, ) # explain - ''' + """ batch feature 1 feature 2 feature 3 input0 [1,2,3] [6,7] [] input1 [] [9] [13,15] @@ -257,10 +270,12 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): ↑ ↑ ↑ rank 0 rank 0 rank 1 in KJT format - ''' - res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), - torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), - already_split_along_rank=False) + """ + res = model( + torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), + torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), + already_split_along_rank=False, + ) optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device) if rank == 0: @@ -273,13 +288,15 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): # check correctness if rank == 0: - ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(), - include_last_offset=True, - freeze=False).to(device) + ref_model = torch.nn.EmbeddingBag.from_pretrained( + weight_tables.detach().clone(), include_last_offset=True, freeze=False + ).to(device) ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2) ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0) - ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), - torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device)) + ref_res = ref_model( + torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device), + torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device), + ) ref_res.backward(ref_fake_grad) ref_optimizer.step() ref_optimizer.zero_grad() @@ -291,7 +308,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): def run_parallel_freq_aware_embed_columnwise(rank, world_size): - device = torch.device('cuda', torch.cuda.current_device()) + device = torch.device("cuda", torch.cuda.current_device()) num_embed = 100 embed_dim = 16 @@ -313,19 +330,20 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size): cache_ratio=batch_size * 2 / num_embed, ) - assert model.cache_weight_mgr.weight.device.type == 'cpu' + assert model.cache_weight_mgr.weight.device.type == "cpu" assert model.cache_weight_mgr.cuda_cached_weight.requires_grad weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank] print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}") - assert torch.allclose(weight_in_rank, - model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}" + assert torch.allclose( + weight_in_rank, model.cache_weight_mgr.weight.detach() + ), f"{weight_in_rank - model.cache_weight_mgr.weight}" optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) if rank == 0: - ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(), - include_last_offset=True, - freeze=False).to(device) + ref_model = torch.nn.EmbeddingBag.from_pretrained( + weight.detach().clone(), include_last_offset=True, freeze=False + ).to(device) ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3) set_seed(4321) @@ -360,19 +378,19 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # run_parallel_freq_aware_embed_columnwise(rank, world_size) run_parallel_freq_aware_embed_tablewise(rank, world_size) @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_parallel_freq_aware_embed(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": # test_freq_aware_embed(True) test_parallel_freq_aware_embed(2) # test_lfu_strategy(False) diff --git a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py index ac9493adab2e..aa4d5d6ceeb3 100644 --- a/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py +++ b/tests/test_legacy/test_layers/test_sequence/checks_seq/check_layer_seq.py @@ -16,6 +16,7 @@ def check_selfattention(): layer = layer.to(get_current_device()) hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device()) - attention_mask = torch.randint(low=0, high=2, - size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device()) - out = layer(hidden_states, attention_mask) + attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to( + get_current_device() + ) + layer(hidden_states, attention_mask) diff --git a/tests/test_legacy/test_layers/test_sequence/test_sequence.py b/tests/test_legacy/test_layers/test_sequence/test_sequence.py index 85226f9d934a..bdd3e04c6479 100644 --- a/tests/test_legacy/test_layers/test_sequence/test_sequence.py +++ b/tests/test_legacy/test_layers/test_sequence/test_sequence.py @@ -8,7 +8,7 @@ from colossalai.legacy.nn.layer.parallel_sequence import RingAV, RingQK from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence'))) +CONFIG = dict(parallel=dict(tensor=dict(size=4, mode="sequence"))) def check_ring_qk(rank, world_size): @@ -26,8 +26,8 @@ def check_ring_qk(rank, world_size): dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) # create distributed tensors - sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() - sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + sub_q = q.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() + sub_k = k.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() # set autograd attributes q.requires_grad = True @@ -47,7 +47,7 @@ def check_ring_qk(rank, world_size): sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length) # check master and distributed attention scores - sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + sub_master_a = a[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2) # run master backward @@ -55,13 +55,12 @@ def check_ring_qk(rank, world_size): a.mean().backward() # run distributed backward - partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + partial_master_a_grad = a.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] torch.autograd.backward(sub_a, partial_master_a_grad) # check master and distributed grads - partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] - assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \ - 'attention score cannot match' + partial_master_q_grad = q.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] + assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), "attention score cannot match" def check_ring_av(rank, world_size): @@ -79,8 +78,8 @@ def check_ring_av(rank, world_size): dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) # create distributed tensors - sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() - sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous() + sub_a = a.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() + sub_v = v.clone()[:, rank * sub_seq_length : (rank + 1) * sub_seq_length].contiguous() # set autograd attributes a.requires_grad = True @@ -102,7 +101,7 @@ def check_ring_av(rank, world_size): # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}') # check master and distributed output - sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + sub_master_out = out[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2) # # run master backward @@ -110,17 +109,16 @@ def check_ring_av(rank, world_size): out.mean().backward() # # run distributed backward - partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] + partial_master_out_grad = out.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] torch.autograd.backward(sub_out, partial_master_out_grad) # # check master and distributed grads - partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length] - assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \ - 'attention output cannot match' + partial_master_a_grad = a.grad[:, rank * sub_seq_length : (rank + 1) * sub_seq_length] + assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), "attention output cannot match" def run_test(rank, world_size, port): - colossalai.legacy.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port) + colossalai.legacy.launch(rank=rank, world_size=world_size, config=CONFIG, host="localhost", port=port) # check_ring_qk(rank, world_size) check_ring_av(rank, world_size) @@ -135,5 +133,5 @@ def test_sequence(): spawn(run_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_sequence() diff --git a/tests/test_legacy/test_pipeline/rpc_test_utils.py b/tests/test_legacy/test_pipeline/rpc_test_utils.py index 9a336c4224be..e59f22062cfc 100644 --- a/tests/test_legacy/test_pipeline/rpc_test_utils.py +++ b/tests/test_legacy/test_pipeline/rpc_test_utils.py @@ -3,12 +3,10 @@ import warnings import torch -import torch.distributed as dist import torch.distributed.rpc as rpc import torch.multiprocessing as mp from torch import nn from torch._C._distributed_rpc import _is_current_rpc_agent_set -from torch.optim import SGD, Adam, Optimizer, RMSprop from colossalai.legacy import launch from colossalai.legacy.pipeline.pipeline_process_group import ppg @@ -17,13 +15,12 @@ rpc_is_initialized = _is_current_rpc_agent_set -def color_debug(text, prefix=' ', color='blue'): +def color_debug(text, prefix=" ", color="blue"): color = color.upper() print(getattr(Back, color), prefix, Style.RESET_ALL, text) class MLP(nn.Module): - def __init__(self, dim: int, layers: int): super().__init__() self.layers = torch.nn.ModuleList() @@ -38,7 +35,6 @@ def forward(self, x): class DAG_MLP(nn.Module): - def __init__(self, dim: int, layers: int): super().__init__() self.layers = torch.nn.ModuleList() @@ -55,12 +51,11 @@ def forward(self, x, y): class RpcTestModel(nn.Module): - def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: super().__init__() self.rank = stage_id self.is_last_rank = stage_id == actual_stage_num - 1 - self.linear_name = f'linear_{stage_id}' + self.linear_name = f"linear_{stage_id}" if stage_id == 0: linear = nn.Linear(feat_num, h) @@ -82,38 +77,38 @@ def forward(self, x) -> torch.Tensor: def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--epoch', type=int, default=1) - parser.add_argument('--world_size', type=int, default=2) - parser.add_argument('--batch_size', type=int, default=16) - parser.add_argument('--dp_degree', type=int, default=1) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--num_microbatches', type=int, default=2) - parser.add_argument('--chunk', type=int, default=1) - parser.add_argument('--use_checkpoint', action='store_true') - parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'RMSprop'], default='SGD') - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29020') - parser.add_argument('--num_worker_threads', type=str, default=128) + parser.add_argument("--epoch", type=int, default=1) + parser.add_argument("--world_size", type=int, default=2) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--num_microbatches", type=int, default=2) + parser.add_argument("--chunk", type=int, default=1) + parser.add_argument("--use_checkpoint", action="store_true") + parser.add_argument("--optimizer", type=str, choices=["SGD", "Adam", "RMSprop"], default="SGD") + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29020") + parser.add_argument("--num_worker_threads", type=str, default=128) return parser.parse_args() def pg_parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--world_size', type=int, default=4) - parser.add_argument('--dp_degree', type=int, default=2) - parser.add_argument('--tp_degree', type=int, default=1) - parser.add_argument('--chunk', type=int, default=1) - parser.add_argument('--num_worker_threads', type=str, default=128) - parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') - parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29020') + parser.add_argument("--world_size", type=int, default=4) + parser.add_argument("--dp_degree", type=int, default=2) + parser.add_argument("--tp_degree", type=int, default=1) + parser.add_argument("--chunk", type=int, default=1) + parser.add_argument("--num_worker_threads", type=str, default=128) + parser.add_argument("--device", type=str, choices=["cpu", "cuda"], default="cuda") + parser.add_argument("--master_addr", type=str, default="localhost") + parser.add_argument("--master_port", type=str, default="29020") return parser.parse_args() def run_worker(rank, args, master_func): - os.environ['MASTER_ADDR'] = args.master_addr - os.environ['MASTER_PORT'] = args.master_port + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port device = args.device world_size = args.world_size @@ -122,17 +117,19 @@ def run_worker(rank, args, master_func): num_worker_threads = args.num_worker_threads host = args.master_addr port = args.master_port - backend = 'nccl' if device == 'cuda' else 'gloo' + backend = "nccl" if device == "cuda" else "gloo" disable_existing_loggers() launch(dict(), rank, world_size, host, int(port), backend, verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=dp_degree, - tp_degree=tp_degree, - num_worker_threads=num_worker_threads, - device=device) + ppg.set_global_info( + rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device, + ) # in rpc mode, only rank 0 is needed to be coded if rank == 0: diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py index 3bff08318d40..f6c077136607 100644 --- a/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_chimera.py @@ -4,7 +4,6 @@ from torch import nn from colossalai.legacy.pipeline.rpc import ChimeraPipelineEngine -from colossalai.testing import assert_close # global variable for model created feat_num = 100 @@ -20,7 +19,7 @@ def partition(pp_rank: int, chunk: int, stage_num: int): def run_master(args): torch.manual_seed(100) - epoch = args.epoch + args.epoch device = args.device stage_num = args.world_size chunk = 1 @@ -32,11 +31,13 @@ def run_master(args): assert sample_num % batch_size == 0 - engine = ChimeraPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - checkpoint=use_checkpoint) + engine = ChimeraPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + checkpoint=use_checkpoint, + ) engine.initialize_optimizer(torch.optim.Adam, lr=1e-3) input_sample = torch.randn((sample_num, feat_num), device=device) @@ -56,7 +57,8 @@ def run_master(args): # compute forward result and backward grad of parameters just in rank_0 test_model = nn.Sequential( - *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)] + ).to(device) # input_sample = input_sample[len(input_sample) // 2:] input_sample = input_sample.requires_grad_() out_val = test_model(input_sample).sum() diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py index eff031ff8faa..806f24a64511 100644 --- a/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_optimizer.py @@ -1,9 +1,9 @@ import torch from rpc_test_utils import RpcTestModel, parse_args, rpc_run from torch import autograd, nn -from torch.optim import SGD, Adam, Optimizer, RMSprop +from torch.optim import Optimizer -from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine from colossalai.testing import assert_close # global variable for model created @@ -36,12 +36,14 @@ def run_master(args): input_sample = torch.randn((sample_num, feat_num), device=device) - engine = OneFOneBPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint) + engine = OneFOneBPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) engine.initialize_optimizer(optimizer_class, lr=lr) @@ -59,7 +61,8 @@ def run_master(args): # compute forward result and backward grad of parameters just in rank_0 test_model = nn.Sequential( - *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)] + ).to(device) optimizer: Optimizer = optimizer_class(test_model.parameters(), lr=lr) input_sample = input_sample.requires_grad_() out_val = test_model(input_sample).sum() diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py index 1a6077f8d3e9..a5e8fc6e6b51 100644 --- a/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_pipeline.py @@ -1,8 +1,7 @@ import torch from rpc_test_utils import RpcTestModel, parse_args, rpc_run -from torch import nn -from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine # global variable for model created feat_num = 100 @@ -32,12 +31,14 @@ def run_master(args): input_sample = torch.randn((sample_num, feat_num), device=device) - engine = OneFOneBPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint) + engine = OneFOneBPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) for _ in range(epoch): _ = engine.forward_backward(input_sample, forward_only=False) diff --git a/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py index 43966ce3dbda..09c9b84a9907 100644 --- a/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py +++ b/tests/test_legacy/test_pipeline/test_cuda_rpc_value_correctness.py @@ -2,7 +2,7 @@ from rpc_test_utils import RpcTestModel, parse_args, rpc_run from torch import autograd, nn -from colossalai.legacy.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine +from colossalai.legacy.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine from colossalai.testing import assert_close feat_num = 100 @@ -32,12 +32,14 @@ def run_master(args): input_sample = torch.randn((sample_num, feat_num), device=device) - engine = OneFOneBPipelineEngine(partition_fn=partition, - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint) + engine = OneFOneBPipelineEngine( + partition_fn=partition, + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) forward_result = engine.forward_backward(input_sample) @@ -54,7 +56,8 @@ def run_master(args): # compute forward result and backward grad of parameters just in rank_0 test_model = nn.Sequential( - *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)]).to(device) + *[partition(pp_rank, chunk, actual_stage_num) for pp_rank in range(actual_stage_num)] + ).to(device) input_sample = input_sample.requires_grad_() out_val = test_model(input_sample).sum() autograd.backward(out_val) diff --git a/tests/test_legacy/test_pipeline/test_middleware_1f1b.py b/tests/test_legacy/test_pipeline/test_middleware_1f1b.py index 4e43d52f8aee..dff04c3ebba1 100644 --- a/tests/test_legacy/test_pipeline/test_middleware_1f1b.py +++ b/tests/test_legacy/test_pipeline/test_middleware_1f1b.py @@ -25,7 +25,7 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): model.eval() tracer = ColoTracer() - meta_args = {k: v.to('meta') for k, v in data_kwargs.items()} + meta_args = {k: v.to("meta") for k, v in data_kwargs.items()} graph = tracer.trace(root=model, meta_args=meta_args) gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) annotated_model = balanced_split_pass(gm, stage_num) @@ -33,7 +33,7 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): - setattr(submodule, '_topo', topo) + setattr(submodule, "_topo", topo) return split_submodules[pp_rank + 1] @@ -47,11 +47,11 @@ def run_master(model_cls, world_size, forward_only): torch.manual_seed(100) epoch = 3 - device = 'cuda' + device = "cuda" stage_num = world_size chunk = 1 num_microbatches = 8 - use_checkpoint = 'store_true' + use_checkpoint = "store_true" if model_cls == MLP: @@ -92,29 +92,26 @@ def data_gen(): checkpoint=use_checkpoint, ) if not forward_only: - engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3) + engine.initialize_optimizer(getattr(torch.optim, "SGD"), lr=1e-3) for _ in range(epoch): input_x = torch.randn((batch_size, dim), device=device) input_y = torch.randn((batch_size, dim), device=device) - logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only) + logits = engine.forward_backward({"x": input_x, "y": input_y}, labels=labels, forward_only=forward_only) def run_worker(rank, world_size, port, model_cls, forward_only, master_func): - master_addr = 'localhost' + master_addr = "localhost" master_port = 29020 - os.environ['MASTER_ADDR'] = master_addr - os.environ['MASTER_PORT'] = str(master_port) + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) disable_existing_loggers() - launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=1, - tp_degree=1, - num_worker_threads=128, - device='cuda') + launch(dict(), rank, world_size, master_addr, master_port, "nccl", verbose=False) + ppg.set_global_info( + rank=rank, world_size=world_size, dp_degree=1, tp_degree=1, num_worker_threads=128, device="cuda" + ) # in rpc mode, only rank 0 is needed to be coded if rank == 0: @@ -125,8 +122,8 @@ def run_worker(rank, world_size, port, model_cls, forward_only, master_func): @pytest.mark.skip("skip due to CI torch version 1.11") -@parameterize('model_cls', [MLP, DAG_MLP]) -@parameterize('forward_only', [True, False]) +@parameterize("model_cls", [MLP, DAG_MLP]) +@parameterize("forward_only", [True, False]) @pytest.mark.dist @rerun_if_address_is_in_use() def test_pp_middleware_fwd(model_cls, forward_only): diff --git a/tests/test_legacy/test_pipeline/test_pipelinable.py b/tests/test_legacy/test_pipeline/test_pipelinable.py index 2ba5d0aa24d8..950cc68036ae 100644 --- a/tests/test_legacy/test_pipeline/test_pipelinable.py +++ b/tests/test_legacy/test_pipeline/test_pipelinable.py @@ -2,14 +2,13 @@ import torch from colossalai.legacy.pipeline.pipelinable import PipelinableContext -from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn NUM_CHUNKS = 1 PIPELINE_SIZE = 2 class MLP(torch.nn.Module): - def __init__(self, dim: int = 256): super().__init__() intermediate_dim = dim * 4 @@ -55,5 +54,5 @@ def test_pipelinable(): spawn(run_pipelinable, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_pipelinable() diff --git a/tests/test_legacy/test_pipeline/test_pipeline_process_group.py b/tests/test_legacy/test_pipeline/test_pipeline_process_group.py index e6b95660279b..627aafb18e61 100644 --- a/tests/test_legacy/test_pipeline/test_pipeline_process_group.py +++ b/tests/test_legacy/test_pipeline/test_pipeline_process_group.py @@ -10,8 +10,8 @@ def run_worker(rank, args): - os.environ['MASTER_ADDR'] = args.master_addr - os.environ['MASTER_PORT'] = args.master_port + os.environ["MASTER_ADDR"] = args.master_addr + os.environ["MASTER_PORT"] = args.master_port device = args.device world_size = args.world_size @@ -20,17 +20,19 @@ def run_worker(rank, args): num_worker_threads = args.num_worker_threads host = args.master_addr port = args.master_port - backend = 'nccl' if device == 'cuda' else 'gloo' + backend = "nccl" if device == "cuda" else "gloo" disable_existing_loggers() launch(dict(), rank, world_size, host, int(port), backend, verbose=False) - ppg.set_global_info(rank=rank, - world_size=world_size, - dp_degree=dp_degree, - tp_degree=tp_degree, - num_worker_threads=num_worker_threads, - device=device) + ppg.set_global_info( + rank=rank, + world_size=world_size, + dp_degree=dp_degree, + tp_degree=tp_degree, + num_worker_threads=num_worker_threads, + device=device, + ) if rpc_is_initialized(): rpc.shutdown() diff --git a/tests/test_legacy/test_tensor/common_utils/_utils.py b/tests/test_legacy/test_tensor/common_utils/_utils.py index b6fea28e4c8a..78bea6658364 100644 --- a/tests/test_legacy/test_tensor/common_utils/_utils.py +++ b/tests/test_legacy/test_tensor/common_utils/_utils.py @@ -13,7 +13,7 @@ def set_seed(seed): random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -27,12 +27,12 @@ def check_equal(A, B): def replace_parameter_add_grad(layer, weight=None, bias=None): if weight is not None: - delattr(layer, 'weight') - setattr(layer, 'weight', weight) + delattr(layer, "weight") + setattr(layer, "weight", weight) layer.weight.requires_grad = True if bias is not None: - delattr(layer, 'bias') - setattr(layer, 'bias', bias) + delattr(layer, "bias") + setattr(layer, "bias", bias) layer.bias.requires_grad = True @@ -47,12 +47,9 @@ def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: return True -def tensor_shard_equal(tensor: torch.Tensor, - shard: torch.Tensor, - rank: int, - world_size: int, - rtol: float = 1e-3, - atol: float = 1e-1): +def tensor_shard_equal( + tensor: torch.Tensor, shard: torch.Tensor, rank: int, world_size: int, rtol: float = 1e-3, atol: float = 1e-1 +): assert tensor.ndim == shard.ndim if tensor.shape == shard.shape: return tensor_equal(tensor, shard, rtol, atol) diff --git a/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py index b6d6bcee66ce..506244447054 100644 --- a/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py +++ b/tests/test_legacy/test_tensor/core/test_dist_spec_mgr.py @@ -48,17 +48,17 @@ def check_mem(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_mem() run() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_dist_spec_mgr(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_dist_spec_mgr(4) diff --git a/tests/test_legacy/test_tensor/test_parameter.py b/tests/test_legacy/test_tensor/test_parameter.py index 7a8694ff6789..5217e22cc422 100644 --- a/tests/test_legacy/test_tensor/test_parameter.py +++ b/tests/test_legacy/test_tensor/test_parameter.py @@ -9,26 +9,27 @@ @pytest.mark.skip def test_multiinheritance(): - colossalai.legacy.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.legacy.launch(config={}, rank=0, world_size=1, host="localhost", port=free_port(), backend="nccl") colo_param = ColoParameter(None, requires_grad=True) - assert colo_param.dist_spec.placement.value == 'r' + assert colo_param.dist_spec.placement.value == "r" assert isinstance(colo_param, ColoTensor) assert isinstance(colo_param, torch.nn.Parameter) # __deepcopy__ overload import copy + colo_param2 = copy.deepcopy(colo_param) assert isinstance(colo_param2, ColoParameter) assert tensor_equal(colo_param.data, colo_param2.data) assert colo_param.requires_grad == colo_param2.requires_grad # __repr__ overload - assert 'ColoParameter' in str(colo_param) + assert "ColoParameter" in str(colo_param) # __torch_function__ clone_param = torch.clone(colo_param) assert isinstance(clone_param, ColoTensor) -if __name__ == '__main__': +if __name__ == "__main__": test_multiinheritance() diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py index 84652093a9fd..a5a2d38577dc 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_p2p.py @@ -8,12 +8,10 @@ from colossalai.legacy.communication import ( recv_backward, recv_forward, - recv_obj_meta, send_backward, send_backward_recv_forward, send_forward, send_forward_recv_backward, - send_obj_meta, ) from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc @@ -39,10 +37,10 @@ def check_forward(output_tensor, rank, logger): tensor = output_tensor.clone() else: tensor = recv_forward(output_tensor.shape) - logger.info('Rank {} received forward. Correct tensor: {}'.format(rank, check_equal(tensor, output_tensor))) + logger.info("Rank {} received forward. Correct tensor: {}".format(rank, check_equal(tensor, output_tensor))) if not gpc.is_last_rank(ParallelMode.PIPELINE): send_forward(tensor) - logger.info('Rank {} sent forward.'.format(rank)) + logger.info("Rank {} sent forward.".format(rank)) def check_backward(output_grad, rank, logger): @@ -51,22 +49,26 @@ def check_backward(output_grad, rank, logger): grad = output_grad.clone() else: grad = recv_backward(output_grad.shape) - logger.info('Rank {} received backward. Correct grad: {}'.format(rank, check_equal(grad, output_grad))) + logger.info("Rank {} received backward. Correct grad: {}".format(rank, check_equal(grad, output_grad))) if not gpc.is_first_rank(ParallelMode.PIPELINE): send_backward(grad) - logger.info('Rank {} sent backward.'.format(rank)) + logger.info("Rank {} sent backward.".format(rank)) def check_forward_backward(output_tensor, output_grad, rank, logger): dist.barrier() if not gpc.is_first_rank(ParallelMode.PIPELINE): tensor = send_backward_recv_forward(output_grad, output_tensor.shape) - logger.info('Rank {} sent backward received forward. Correct tensor: {}'.format( - rank, check_equal(tensor, output_tensor))) + logger.info( + "Rank {} sent backward received forward. Correct tensor: {}".format( + rank, check_equal(tensor, output_tensor) + ) + ) if not gpc.is_last_rank(ParallelMode.PIPELINE): grad = send_forward_recv_backward(output_tensor, output_grad.shape) - logger.info('Rank {} sent forward received backward. Correct grad: {}'.format( - rank, check_equal(grad, output_grad))) + logger.info( + "Rank {} sent forward received backward. Correct grad: {}".format(rank, check_equal(grad, output_grad)) + ) def check_comm(size, rank, prev_rank, next_rank, logger): @@ -84,13 +86,13 @@ def check_comm(size, rank, prev_rank, next_rank, logger): def run_check(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") logger = get_dist_logger() rank = gpc.get_global_rank() prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - logger.info('Rank {0}: prev rank {1}, next rank {2}'.format(rank, prev_rank, next_rank)) - logger.info('Distributed environment is initialized.') + logger.info("Rank {0}: prev rank {1}, next rank {2}".format(rank, prev_rank, next_rank)) + logger.info("Distributed environment is initialized.") check_comm(world_size, rank, prev_rank, next_rank, logger) gpc.destroy() @@ -104,5 +106,5 @@ def test_p2p(): spawn(run_check, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_p2p() diff --git a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py index fd94c279b6fb..cd7fcfe5635d 100644 --- a/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py +++ b/tests/test_legacy/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -23,7 +23,7 @@ def run_schedule(rank, world_size, port): - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # build model model = resnet18(num_classes=10) @@ -33,20 +33,23 @@ def run_schedule(rank, world_size, port): elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: class Flatten(nn.Module): - def forward(self, x): return torch.flatten(x, 1) model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) - print_rank_0('model is created') - - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), - ])) + print_rank_0("model is created") + + train_dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), + ] + ), + ) train_dataloader = get_dataloader( dataset=train_dataset, @@ -83,5 +86,5 @@ def test_pipeline_schedule(): spawn(run_schedule, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_schedule() diff --git a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py index 4a240533474c..d19b12a5b044 100644 --- a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py +++ b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -16,16 +16,15 @@ CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH)) -@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'nested_model']) +@parameterize("model_name", ["repeated_computed_layers", "resnet18", "nested_model"]) def run_trainer(model_name): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model = model_builder() optimizer = optimizer_class(model.parameters(), lr=1e-3) - engine, train_dataloader, *_ = colossalai.legacy.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) + engine, train_dataloader, *_ = colossalai.legacy.initialize( + model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader + ) logger = get_dist_logger() logger.info("engine is built", ranks=[0]) @@ -35,22 +34,21 @@ def run_trainer(model_name): logger.info("trainer is built", ranks=[0]) logger.info("start training", ranks=[0]) - trainer.fit(train_dataloader=train_dataloader, - test_dataloader=test_dataloader, - epochs=NUM_EPOCHS, - max_steps=3, - display_progress=True, - test_interval=5) + trainer.fit( + train_dataloader=train_dataloader, + test_dataloader=test_dataloader, + epochs=NUM_EPOCHS, + max_steps=3, + display_progress=True, + test_interval=5, + ) torch.cuda.empty_cache() def run_dist(rank, world_size, port): - colossalai.legacy.launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) @pytest.mark.dist @@ -60,5 +58,5 @@ def test_trainer_no_pipeline(): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_trainer_no_pipeline() diff --git a/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py index 521b2f32f22d..0b34a79f96dd 100644 --- a/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py +++ b/tests/test_legacy/test_trainer/test_trainer_with_pipe_schedule.py @@ -29,12 +29,9 @@ def run_trainer_with_pipeline(rank, world_size, port): - colossalai.legacy.launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') + colossalai.legacy.launch( + config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl" + ) # build model model = resnet18(num_classes=10) @@ -44,35 +41,35 @@ def run_trainer_with_pipeline(rank, world_size, port): elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1: class Flatten(nn.Module): - def forward(self, x): return torch.flatten(x, 1) model = nn.Sequential(model.layer3, model.layer4, model.avgpool, Flatten(), model.fc) # build dataloaders - train_dataset = CIFAR10(root=Path(os.environ['DATA']), - download=True, - transform=transforms.Compose([ - transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), - transforms.ToTensor(), - transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ])) - - train_dataloader = get_dataloader(dataset=train_dataset, - shuffle=True, - batch_size=BATCH_SIZE, - pin_memory=True, - drop_last=True) + train_dataset = CIFAR10( + root=Path(os.environ["DATA"]), + download=True, + transform=transforms.Compose( + [ + transforms.Resize(size=(IMG_SIZE, IMG_SIZE)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + ] + ), + ) + + train_dataloader = get_dataloader( + dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True + ) # build optimizer optimizer = Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() - engine, train_dataloader, *args = colossalai.legacy.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) + engine, train_dataloader, *args = colossalai.legacy.initialize( + model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader + ) logger = get_dist_logger() logger.info("engine is built", ranks=[0]) @@ -82,11 +79,9 @@ def forward(self, x): logger.info("start training", ranks=[0]) - trainer.fit(train_dataloader=train_dataloader, - epochs=NUM_EPOCHS, - max_steps=3, - display_progress=True, - test_interval=5) + trainer.fit( + train_dataloader=train_dataloader, epochs=NUM_EPOCHS, max_steps=3, display_progress=True, test_interval=5 + ) gpc.destroy() torch.cuda.empty_cache() @@ -98,5 +93,5 @@ def test_trainer_with_pipeline(): spawn(run_trainer_with_pipeline, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_trainer_with_pipeline() diff --git a/tests/test_legacy/test_utils/test_activation_checkpointing.py b/tests/test_legacy/test_utils/test_activation_checkpointing.py index 19984ae120b5..3303f610ee82 100644 --- a/tests/test_legacy/test_utils/test_activation_checkpointing.py +++ b/tests/test_legacy/test_utils/test_activation_checkpointing.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import pytest import torch import torch.nn.functional as F @@ -44,20 +43,19 @@ def forward_inplace(x, weight): @parameterize("use_reentrant", [True, False]) @parameterize("cpu_offload", [True, False]) def test_activation_checkpointing(cpu_offload, use_reentrant): - # as seed manager is singleton # if we don't reset seeds here, # other tests might affect this test reset_seeds() # We put initialization here to avoid change cuda rng state below - inputs = torch.rand(2, 2, requires_grad=True, device='cuda') - weight = torch.rand(2, 4, requires_grad=True, device='cuda') + inputs = torch.rand(2, 2, requires_grad=True, device="cuda") + weight = torch.rand(2, 4, requires_grad=True, device="cuda") # Get a copy of input tensors - inputs_ = torch.empty(2, 2, requires_grad=True, device='cuda') + inputs_ = torch.empty(2, 2, requires_grad=True, device="cuda") inputs_.data.copy_(inputs.data) - weight_ = torch.empty(2, 4, requires_grad=True, device='cuda') + weight_ = torch.empty(2, 4, requires_grad=True, device="cuda") weight_.data.copy_(weight.data) add_seed(ParallelMode.GLOBAL, 1024) @@ -83,7 +81,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant): loss = out.sum() loss.backward() - assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match" torch.cuda.empty_cache() # Extra test for use_reentrant=False @@ -110,7 +108,7 @@ def test_activation_checkpointing(cpu_offload, use_reentrant): loss = out.sum() loss.backward() - assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match' + assert torch.all(inputs.grad == inputs_.grad), "Gradient of the input does not match" torch.cuda.empty_cache() # as seed manager is singleton diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py index 88cd89a217fe..c07ff132b79e 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_1d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) + config = dict( + parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py index 591cd714fc65..2ec1facf21b1 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_2d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) + config = dict( + parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py index b165b4276f10..a6bf702a8482 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_2p5d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) + config = dict( + parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py index 2ce054d33b2d..12d928312969 100644 --- a/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py +++ b/tests/test_legacy/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -38,7 +38,9 @@ def check_equal(A, B): def check_checkpoint_3d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) + config = dict( + parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")), + ) disable_existing_loggers() launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py index 2e25dc773b68..9416ac86e325 100644 --- a/tests/test_legacy/test_utils/test_memory.py +++ b/tests/test_legacy/test_utils/test_memory.py @@ -14,7 +14,7 @@ def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): def run_dist(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity() @@ -24,5 +24,5 @@ def test_memory_utils(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_memory_utils(world_size=2) diff --git a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py index 918f174aba76..b5f2be705890 100644 --- a/tests/test_legacy/test_utils/test_norm_gradient_clipping.py +++ b/tests/test_legacy/test_utils/test_norm_gradient_clipping.py @@ -28,20 +28,20 @@ def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None: grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()] else: grad = p.grad - assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}' + assert torch.allclose(grad, colo_p.grad), f"diff: {torch.abs(grad - colo_p.grad)}" -@parameterize('dtype', [torch.float]) -@parameterize('device', ['mixed', 'cuda', 'cpu']) -@parameterize('norm_type', [2.0, 3.0, float('inf')]) +@parameterize("dtype", [torch.float]) +@parameterize("device", ["mixed", "cuda", "cpu"]) +@parameterize("norm_type", [2.0, 3.0, float("inf")]) def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float): - print(f'{world_size}, {dtype}, {device}, {norm_type}') + print(f"{world_size}, {dtype}, {device}, {norm_type}") cuda_device = get_current_device() devices = [cuda_device] * 4 - if device == 'cpu': - devices = [torch.device('cpu')] * 4 - elif device == 'mixed': - devices = [cuda_device] * 2 + [torch.device('cpu')] * 2 + if device == "cpu": + devices = [torch.device("cpu")] * 4 + elif device == "mixed": + devices = [cuda_device] * 2 + [torch.device("cpu")] * 2 pg = ProcessGroup(tp_degree=world_size) params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)] colo_params = [ @@ -55,24 +55,24 @@ def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_ty shard_param(colo_params[2]) torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type) colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type) - assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm-colo_norm)}' + assert close(torch_norm, colo_norm), f"diff: {abs(torch_norm-colo_norm)}" for p, colo_p in zip(params, colo_params): check_grad_equal(p, colo_p) def run_dist(rank, world_size, port): disable_existing_loggers() - colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_grad_clip_norm(world_size=world_size) @pytest.mark.skip("this need to be updated") @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_zero_clip_grad(world_size: int): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_clip_grad(2) diff --git a/tests/test_legacy/test_zero/test_commons.py b/tests/test_legacy/test_zero/test_commons.py index 42a9f1eecb95..741f519e1376 100644 --- a/tests/test_legacy/test_zero/test_commons.py +++ b/tests/test_legacy/test_zero/test_commons.py @@ -7,29 +7,29 @@ def run_tensor_move(rank, world_size, port): - colossalai.legacy.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.legacy.launch(config={}, rank=0, world_size=world_size, host="localhost", port=port, backend="nccl") src_t = torch.ones(2, 3).cuda() tgt_t = torch.zeros(2, 3) colo_model_data_tensor_move(src_t, tgt_t) - assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + assert torch.sum(tgt_t) == 6.0, f"{torch.sum(tgt_t.payload)} vs. 6.0" src_t = torch.ones(2, 3) tgt_t = torch.zeros(2, 3).cuda().half() colo_model_data_tensor_move(src_t, tgt_t) # the src_t has been removed - assert (src_t.numel() == 0) - assert (torch.sum(tgt_t) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + assert src_t.numel() == 0 + assert torch.sum(tgt_t) == 6.0, f"{torch.sum(tgt_t.payload)} vs. 6.0" src_t = ShardedTensor(torch.ones(2, 3)) tgt_t = ShardedTensor(torch.zeros(2, 3).cuda().half()) colo_model_data_tensor_move(src_t, tgt_t) - assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + assert torch.sum(tgt_t.payload) == 6.0, f"{torch.sum(tgt_t.payload)} vs. 6.0" - assert (tgt_t.device.type == 'cuda') - colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu')) - assert (tgt_t.device.type == 'cpu') + assert tgt_t.device.type == "cuda" + colo_model_data_tensor_move_inline(tgt_t, torch.device("cpu")) + assert tgt_t.device.type == "cpu" @rerun_if_address_is_in_use() @@ -37,5 +37,5 @@ def test_tensor_move(): spawn(run_tensor_move, 1) -if __name__ == '__main__': +if __name__ == "__main__": test_tensor_move() diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 9c84a99cd549..8742e5f41136 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -17,11 +17,11 @@ def run_test(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") expert_module = nn.Linear expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device()) - MOE_CONTEXT.setup(42) # MOE initialization + MOE_CONTEXT.setup(42) # MOE initialization noisy_func = UniformNoiseGenerator() router = Top1Router(noisy_func=noisy_func) num_experts_list = [1, 2, 4] @@ -67,5 +67,5 @@ def test_grad_handler(): spawn(run_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_grad_handler() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index c096b6075005..7a9c551d679d 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -23,12 +23,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # Here we do not need TF32, since it brings absolute error on results torch.backends.cuda.matmul.allow_tf32 = False - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") local_rank = gpc.get_local_rank(ParallelMode.GLOBAL) - MOE_CONTEXT.setup(42) # MOE environment initialization + MOE_CONTEXT.setup(42) # MOE environment initialization MOE_CONTEXT.reset_loss() - torch.manual_seed(rs + local_rank) # set each process has different random seed + torch.manual_seed(rs + local_rank) # set each process has different random seed # get randomized data tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True) @@ -46,7 +46,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f old_out, _ = layer(tokens) ech = old_out.shape grad = torch.randn(ech, device=get_current_device()) - old_out.backward(grad) # get gradient + old_out.backward(grad) # get gradient # save all results o_tk_grad = tokens.grad.data.clone() @@ -57,7 +57,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer.gate_weight.grad.zero_() layer.use_kernel = True - new_out, _ = layer(tokens) # get outputs through colossal kernel + new_out, _ = layer(tokens) # get outputs through colossal kernel if data_type == torch.float32: check_equal(old_out, new_out) @@ -65,7 +65,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f check_equal(old_out, new_out, 1e-2) # forward function passed - new_out.backward(grad) # get new type gradient + new_out.backward(grad) # get new type gradient n_tk_grad = tokens.grad.data.clone() n_gt_grad = layer.gate_weight.grad.data.clone() @@ -92,5 +92,5 @@ def test_moe_kernel(rs, hidden_size, data_type, router): spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_kernel(2, 256, torch.float16, Top2Router) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 8a0283ba71fc..b7024f32b1cf 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -17,11 +17,11 @@ def exam_moe_checkpoint(): with ColoInitContext(device=get_current_device()): model = MoeModel(checkpoint=True) - save_moe_model(model, 'temp_path.pth') + save_moe_model(model, "temp_path.pth") with ColoInitContext(device=get_current_device()): other_model = MoeModel(checkpoint=True) - load_moe_model(other_model, 'temp_path.pth') + load_moe_model(other_model, "temp_path.pth") state_0 = model.state_dict() state_1 = other_model.state_dict() @@ -30,11 +30,11 @@ def exam_moe_checkpoint(): assert torch.equal(u.data, v.data) if dist.get_rank() == 0: - os.remove('temp_path.pth') + os.remove("temp_path.pth") def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) exam_moe_checkpoint() @@ -46,5 +46,5 @@ def test_moe_checkpoint(world_size): spawn(_run_dist) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_checkpoint(world_size=4) diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py index 555338fcf9fc..488573b733b1 100644 --- a/tests/test_moe/test_moe_colo_init.py +++ b/tests/test_moe/test_moe_colo_init.py @@ -9,17 +9,16 @@ from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_tensor.common_utils import debug_print from tests.test_zero.test_legacy.common import CONFIG -@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("init_device_type", ["cpu", "cuda"]) def exam_moe_colo_init(init_device_type): world_size = dist.get_world_size() - if init_device_type == 'cuda': + if init_device_type == "cuda": init_device = get_current_device() - elif init_device_type == 'cpu': + elif init_device_type == "cpu": init_device = torch.device("cpu") else: raise NotImplementedError("Unknown device found.") @@ -40,7 +39,7 @@ def exam_moe_colo_init(init_device_type): def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) exam_moe_colo_init() @@ -52,5 +51,5 @@ def test_moe_colo_init(world_size): spawn(_run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_colo_init(world_size=4) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 6dc3f5f18b6d..300fb6c99b7b 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -16,11 +16,11 @@ def run_test(rank, world_size, port): world_size = 4 - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") expert_module = nn.Linear expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device()) - MOE_CONTEXT.setup(42) # MOE environment initialization + MOE_CONTEXT.setup(42) # MOE environment initialization exp0 = Experts(expert_module, 1, **expert_factor) exp1 = Experts(expert_module, 2, **expert_factor) exp2 = Experts(expert_module, 4, **expert_factor) @@ -64,5 +64,5 @@ def test_moe_initialization(): spawn(run_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_initialization() diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 79722f9f4056..c48f9a3557ce 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -15,20 +15,15 @@ class MoeModel(nn.Module): - def __init__(self, checkpoint: bool = False): - class TestSubModule(CheckpointModule): - def __init__(self): super().__init__(checkpoint) expert_cls = nn.Linear expert_args_dict = dict(in_features=16, out_features=16) - self.moe = MoeModule(dim_model=16, - num_experts=8, - use_residual=True, - expert_cls=expert_cls, - **expert_args_dict) + self.moe = MoeModule( + dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict + ) self.proj = nn.Linear(16, 4) def _forward(self, x): @@ -50,49 +45,52 @@ def forward(self, x): return x -@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("init_device_type", ["cpu", "cuda"]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) def run_moe_zero_init(init_device_type, shard_strategy_class): - logger = get_dist_logger("test_moe_zero_init") + get_dist_logger("test_moe_zero_init") - if init_device_type == 'cuda': + if init_device_type == "cuda": init_device = get_current_device() - elif init_device_type == 'cpu': + elif init_device_type == "cpu": init_device = torch.device("cpu") else: raise NotImplementedError("Unknown device found.") model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext(target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor): + with ZeroInitContext( + target_device=init_device, + shard_strategy=shard_strategy_class(), + shard_param=True, + model_numel_tensor=model_numel_tensor, + ): model = MoeModel(checkpoint=True) for name, param in model.named_parameters(): - assert hasattr(param, 'colo_attr') + assert hasattr(param, "colo_attr") # the parameters in moe experts and its gate should not be sharded - if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): + if ("experts" in name) or ("gate" in name) or ("residual_combine" in name): assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) else: assert param.colo_attr.sharded_data_tensor.is_sharded # the parameters in moe experts is not replicated - if 'experts' in name: + if "experts" in name: assert not param.colo_attr.is_replicated else: assert param.colo_attr.is_replicated if param.colo_attr.param_is_sharded: - assert param.colo_attr.data_payload.device.type == init_device.type, \ - f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' + assert ( + param.colo_attr.data_payload.device.type == init_device.type + ), f"{param.colo_attr.data_payload.device.type} vs. {init_device.type}" else: - assert param.colo_attr.data_payload.device.type == 'cuda' + assert param.colo_attr.data_payload.device.type == "cuda" def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) run_moe_zero_init() @@ -104,5 +102,5 @@ def test_moe_zero_init(world_size): spawn(_run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_init(world_size=2) diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index 595d4374df6f..724d70d77bc6 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -21,13 +21,13 @@ def run_model_test(enable_autocast, shard_strategy_class): shard_strategy = shard_strategy_class() - get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model") _, train_dataloader, _, optimizer_class, _ = get_components_func() criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), - shard_strategy=shard_strategy, - shard_param=True): + with ZeroInitContext( + target_device=torch.device("cuda", torch.cuda.current_device()), shard_strategy=shard_strategy, shard_param=True + ): zero_model = MoeModel(checkpoint=True) zero_model = ShardedModelV2(zero_model, shard_strategy) @@ -54,7 +54,7 @@ def run_model_test(enable_autocast, shard_strategy_class): def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) run_model_test() @@ -66,5 +66,5 @@ def test_moe_zero_model(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_model(world_size=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 35fde6f10f3f..bb9822daee05 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -43,31 +43,33 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler): @parameterize("cpu_offload", [True]) -@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug +@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug @parameterize("reuse_fp16_shard", [True, False]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def _run_test_sharded_optim_v2(cpu_offload, - shard_strategy_class, - use_cpuadam, - reuse_fp16_shard, - gpu_margin_mem_ratio=0.0): +def _run_test_sharded_optim_v2( + cpu_offload, shard_strategy_class, use_cpuadam, reuse_fp16_shard, gpu_margin_mem_ratio=0.0 +): shard_strategy = shard_strategy_class() if use_cpuadam and cpu_offload is False: return MOE_CONTEXT.reset_loss() - get_components_func = non_distributed_component_funcs.get_callable('hanging_param_model') + get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model") _, train_dataloader, _, optimizer_class, _ = get_components_func() criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) - with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): + with ZeroInitContext( + target_device=torch.device("cpu") if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True, + ): zero_model = MoeModel(checkpoint=True) - zero_model = ShardedModelV2(zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=reuse_fp16_shard) + zero_model = ShardedModelV2( + zero_model, + shard_strategy, + tensor_placement_policy="cpu" if cpu_offload else "cuda", + reuse_fp16_shard=reuse_fp16_shard, + ) # check whether parameters are identical in ddp for name, p in zero_model.named_parameters(): @@ -82,12 +84,11 @@ def _run_test_sharded_optim_v2(cpu_offload, optimizer_class = CPUAdam optim = optimizer_class(model.parameters(), lr=1e-3) sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, - sharded_optim, - initial_scale=2**5, - gpu_margin_mem_ratio=gpu_margin_mem_ratio) + sharded_optim = ShardedOptimizerV2( + zero_model, sharded_optim, initial_scale=2**5, gpu_margin_mem_ratio=gpu_margin_mem_ratio + ) - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False) apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) apex_grad_handler = MoeGradientHandler(model) @@ -103,7 +104,7 @@ def _run_test_sharded_optim_v2(cpu_offload, def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") MOE_CONTEXT.setup(seed=42) _run_test_sharded_optim_v2() @@ -116,5 +117,5 @@ def test_moe_zero_optim(world_size): spawn(_run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_moe_zero_optim(world_size=4) diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 2186a421fe00..8131ea3234d8 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -10,16 +10,25 @@ from colossalai.utils import get_current_device, multi_tensor_applier -_FUSED_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float), - (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16), - (torch.bfloat16, torch.bfloat16)] - -_CPU_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float), - (torch.half, torch.half)] +_FUSED_ALLOWED_P_G_TYPES = [ + (torch.float, torch.half), + (torch.float, torch.float), + (torch.half, torch.float), + (torch.half, torch.half), + (torch.bfloat16, torch.float), + (torch.float, torch.bfloat16), + (torch.bfloat16, torch.bfloat16), +] + +_CPU_ALLOWED_P_G_TYPES = [ + (torch.float, torch.half), + (torch.float, torch.float), + (torch.half, torch.float), + (torch.half, torch.half), +] class AdamKernel: - def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: self.lr = lr self.beta1 = beta1 @@ -34,7 +43,6 @@ def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_av class TorchAdamKernel(AdamKernel): - def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): bias_correction1 = 1 - self.beta1**step bias_correction2 = 1 - self.beta2**step @@ -57,36 +65,68 @@ def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_av class FusedAdamKernel(AdamKernel): - def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() self.fused_adam = fused_optim.multi_tensor_adam self.dummy_overflow_buf = torch.cuda.IntTensor([0]) def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): - multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, [[grad], [param], [exp_avg], [exp_avg_sq]], - self.lr, self.beta1, self.beta2, self.eps, step, self.use_adamw, True, self.weight_decay, - -1) + multi_tensor_applier( + self.fused_adam, + self.dummy_overflow_buf, + [[grad], [param], [exp_avg], [exp_avg_sq]], + self.lr, + self.beta1, + self.beta2, + self.eps, + step, + self.use_adamw, + True, + self.weight_decay, + -1, + ) class CPUAdamKernel(AdamKernel): - def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) from colossalai.kernel.op_builder import CPUAdamBuilder + cpu_optim = CPUAdamBuilder().load() self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw) def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): - self.cpu_adam_op.step(step, self.lr, self.beta1, self.beta2, self.eps, self.weight_decay, True, param.view(-1), - grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1), -1) - - -def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float, p_dtype: torch.dtype, - g_dtype: torch.dtype, device: torch.device, n_steps: int, rtol: float, atol: float): + self.cpu_adam_op.step( + step, + self.lr, + self.beta1, + self.beta2, + self.eps, + self.weight_decay, + True, + param.view(-1), + grad.view(-1), + exp_avg.view(-1), + exp_avg_sq.view(-1), + -1, + ) + + +def check_adam_kernel( + kernel: Type[AdamKernel], + adamw: bool, + weight_decay: float, + p_dtype: torch.dtype, + g_dtype: torch.dtype, + device: torch.device, + n_steps: int, + rtol: float, + atol: float, +): lr = 1e-3 beta1, beta2 = 0.9, 0.999 eps = 1e-8 @@ -109,9 +149,9 @@ def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol) -@pytest.mark.parametrize('adamw', [False, True]) -@pytest.mark.parametrize('weight_decay', [0.0, 0.1]) -@pytest.mark.parametrize('p_dtype, g_dtype', _FUSED_ALLOWED_P_G_TYPES) +@pytest.mark.parametrize("adamw", [False, True]) +@pytest.mark.parametrize("weight_decay", [0.0, 0.1]) +@pytest.mark.parametrize("p_dtype, g_dtype", _FUSED_ALLOWED_P_G_TYPES) def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): rtol, atol = 1e-5, 1e-8 if p_dtype is torch.float16 or g_dtype is torch.float16: @@ -121,11 +161,11 @@ def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol) -@pytest.mark.parametrize('adamw', [False, True]) -@pytest.mark.parametrize('weight_decay', [0.0, 0.1]) -@pytest.mark.parametrize('p_dtype, g_dtype', _CPU_ALLOWED_P_G_TYPES) +@pytest.mark.parametrize("adamw", [False, True]) +@pytest.mark.parametrize("weight_decay", [0.0, 0.1]) +@pytest.mark.parametrize("p_dtype, g_dtype", _CPU_ALLOWED_P_G_TYPES) def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): rtol, atol = 1e-5, 1e-8 if p_dtype is torch.float16 or g_dtype is torch.float16: rtol, atol = 1e-3, 1e-3 - check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device('cpu'), 3, rtol, atol) + check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device("cpu"), 3, rtol, atol) diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py index 0f72bc134809..59b40a0afa3c 100644 --- a/tests/test_optimizer/test_adam_optim.py +++ b/tests/test_optimizer/test_adam_optim.py @@ -10,17 +10,17 @@ from tests.kit.model_zoo import model_zoo _ALLOWED_OPTIM_DEVICES = [ - (FusedAdam, torch.device('cuda:0')), - (CPUAdam, torch.device('cpu')), - (CPUAdam, torch.device('cuda:0')), - (HybridAdam, torch.device('cpu')), - (HybridAdam, torch.device('cuda:0')), + (FusedAdam, torch.device("cuda:0")), + (CPUAdam, torch.device("cpu")), + (CPUAdam, torch.device("cuda:0")), + (HybridAdam, torch.device("cpu")), + (HybridAdam, torch.device("cuda:0")), ] _ALLOWED_P_G_TYPES = [ - (torch.float, torch.float), # pure fp32 - (torch.float, torch.half), # fp16 amp - (torch.float, torch.bfloat16), # bfloat16 amp + (torch.float, torch.float), # pure fp32 + (torch.float, torch.half), # fp16 amp + (torch.float, torch.bfloat16), # bfloat16 amp # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16 # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16 ] @@ -53,12 +53,17 @@ def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> p.data = orig_p -@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES) -@pytest.mark.parametrize('adamw', [False, True]) -@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES) -def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device, - adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None: - model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values())) +@pytest.mark.parametrize("optim_cls, device", _ALLOWED_OPTIM_DEVICES) +@pytest.mark.parametrize("adamw", [False, True]) +@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES) +def test_adam_optim_on_bert( + optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], + device: torch.device, + adamw: bool, + p_dtype: torch.dtype, + g_dtype: torch.dtype, +) -> None: + model_fn, *_ = next(iter(model_zoo.get_sub_registry("transformers_bert_for_sequence_classification").values())) torch_model = model_fn().to(device) model = deepcopy(torch_model).to(p_dtype) lr = 1e-3 diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py index 5d794ac2dd1a..a68a9c51855f 100644 --- a/tests/test_optimizer/test_nvme.py +++ b/tests/test_optimizer/test_nvme.py @@ -1,4 +1,3 @@ -import pytest import torch from colossalai.nn.optimizer import CPUAdam, HybridAdam @@ -15,23 +14,22 @@ def move_some_params_to_cuda(model, torch_model): def check_params_equal(model, torch_model): for p, torch_p in zip(model.parameters(), torch_model.parameters()): - assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}' + assert torch.allclose(p, torch_p, atol=1e-3), f"diff: {torch.abs(p - torch_p)}" @clear_cache_before_run() -@parameterize('nvme_offload_fraction', [0.0, 0.5, 1.0]) -@parameterize('nvme_offload_dir', ['./offload', None]) -@parameterize('adam_cls', [CPUAdam, HybridAdam]) +@parameterize("nvme_offload_fraction", [0.0, 0.5, 1.0]) +@parameterize("nvme_offload_dir", ["./offload", None]) +@parameterize("adam_cls", [CPUAdam, HybridAdam]) def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): - get_components_func = non_distributed_component_funcs.get_callable('simple_net') + get_components_func = non_distributed_component_funcs.get_callable("simple_net") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model = model_builder() torch_model = model_builder() move_some_params_to_cuda(model, torch_model) - optimizer = adam_cls(model.parameters(), - lr=0.1, - nvme_offload_fraction=nvme_offload_fraction, - nvme_offload_dir=nvme_offload_dir) + optimizer = adam_cls( + model.parameters(), lr=0.1, nvme_offload_fraction=nvme_offload_fraction, nvme_offload_dir=nvme_offload_dir + ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.1) with torch.no_grad(): for p, torch_p in zip(model.parameters(), torch_model.parameters()): @@ -45,5 +43,5 @@ def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): check_params_equal(model, torch_model) -if __name__ == '__main__': - test_nvme_adam(0.5, './offload', CPUAdam) +if __name__ == "__main__": + test_nvme_adam(0.5, "./offload", CPUAdam) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 71946f6b988a..1665711ceeef 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -22,30 +22,30 @@ def check_p2p_communication(): if rank == 0: p2p.send_forward(tensor) p2p.send_forward([tensor]) - p2p.send_forward({'tensor': tensor}) + p2p.send_forward({"tensor": tensor}) else: obj = p2p.recv_forward() assert torch.equal(obj, tensor) obj = p2p.recv_forward() assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) obj = p2p.recv_forward() - assert type(obj) == dict and 'tensor' in obj and torch.equal(obj['tensor'], tensor) + assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) if rank == 1: p2p.send_backward(tensor) p2p.send_backward([tensor]) - p2p.send_backward({'tensor': tensor}) + p2p.send_backward({"tensor": tensor}) else: obj = p2p.recv_backward() assert torch.equal(obj, tensor) obj = p2p.recv_backward() assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) obj = p2p.recv_backward() - assert type(obj) == dict and 'tensor' in obj and torch.equal(obj['tensor'], tensor) + assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") check_p2p_communication() @@ -55,5 +55,5 @@ def test_pipeline_p2p(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_p2p() diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index 0cbb852b97a0..3723c9c1014a 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -4,36 +4,42 @@ def test_t5_pipeline_distribution(): num_test_cases = 8 test_dict = { - 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], - 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], - 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], - 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + "num_encoder_layers": [2, 1, 3, 2, 3, 2, 10, 5], + "num_decoder_layers": [2, 8, 0, 2, 1, 5, 6, 22], + "num_stages": [2, 2, 2, 4, 4, 4, 8, 8], + "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } for i in range(num_test_cases): - _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i], - test_dict['num_decoder_layers'][i], - test_dict['num_stages'][i]) - assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) + assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage def test_t5_pipeline_layers(): num_test_cases = 4 test_dict = { - 'num_encoder_layers': [2, 3, 2, 4], - 'num_decoder_layers': [2, 0, 2, 8], - 'num_stages': [2, 2, 4, 4], - 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], - [[0, 4], [0, 3], [3, 6], [6, 8]]] + "num_encoder_layers": [2, 3, 2, 4], + "num_decoder_layers": [2, 0, 2, 8], + "num_stages": [2, 2, 4, 4], + "layers_per_stage": [ + [[0, 2], [0, 2]], + [[0, 1], [1, 3]], + [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]], + ], } for i in range(num_test_cases): layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers( - test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) - for stage in range(test_dict['num_stages'][i]): - start_idx, end_idx = test_dict['layers_per_stage'][i][stage] - predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage, - decoder_starting_stage) + for stage in range(test_dict["num_stages"][i]): + start_idx, end_idx = test_dict["layers_per_stage"][i][stage] + predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index( + layers_per_stage, stage, decoder_starting_stage + ) assert start_idx == predicted_start assert end_idx == predicted_end diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index 395519e97898..f6be8f6feac2 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -4,41 +4,47 @@ def test_whisper_pipeline_distribution(): num_test_cases = 8 test_dict = { - 'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5], - 'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22], - 'num_stages': [2, 2, 2, 4, 4, 4, 8, 8], - 'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2] + "num_encoder_layers": [2, 1, 3, 2, 3, 2, 10, 5], + "num_decoder_layers": [2, 8, 0, 2, 1, 5, 6, 22], + "num_stages": [2, 2, 2, 4, 4, 4, 8, 8], + "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } for i in range(num_test_cases): - _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i], - test_dict['num_decoder_layers'][i], - test_dict['num_stages'][i]) - assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage + _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) + assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage def test_whisper_pipeline_layers(): num_test_cases = 4 test_dict = { - 'num_encoder_layers': [2, 3, 2, 4], - 'num_decoder_layers': [2, 0, 2, 8], - 'num_stages': [2, 2, 4, 4], - 'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]], - [[0, 4], [0, 3], [3, 6], [6, 8]]] + "num_encoder_layers": [2, 3, 2, 4], + "num_decoder_layers": [2, 0, 2, 8], + "num_stages": [2, 2, 4, 4], + "layers_per_stage": [ + [[0, 2], [0, 2]], + [[0, 1], [1, 3]], + [[0, 1], [1, 2], [0, 1], [1, 2]], + [[0, 4], [0, 3], [3, 6], [6, 8]], + ], } for i in range(num_test_cases): layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers( - test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i]) - - for stage in range(test_dict['num_stages'][i]): - start_idx, end_idx = test_dict['layers_per_stage'][i][stage] - predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage, - decoder_starting_stage) + test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i] + ) + + for stage in range(test_dict["num_stages"][i]): + start_idx, end_idx = test_dict["layers_per_stage"][i][stage] + predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index( + layers_per_stage, stage, decoder_starting_stage + ) assert start_idx == predicted_start assert end_idx == predicted_end -if __name__ == '__main__': +if __name__ == "__main__": test_whisper_pipeline_distribution() test_whisper_pipeline_layers() diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index a995d17e5da6..f181453eaed5 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -16,7 +16,6 @@ class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(4, 8) @@ -40,19 +39,20 @@ def forward(self, x): return x -def pp_linear_fwd(forward, - data: torch.Tensor = None, - input_obj: torch.Tensor = None, - stage_mgr: PipelineStageManager = None, - num_chunks: int = None, - model_chunk_id: int = None): - +def pp_linear_fwd( + forward, + data: torch.Tensor = None, + input_obj: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, + num_chunks: int = None, + model_chunk_id: int = None, +): if stage_mgr.is_first_stage() and model_chunk_id == 0: - return {'input_obj': forward(data)} + return {"input_obj": forward(data)} elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1: return forward(input_obj) else: - return {'input_obj': forward(input_obj)} + return {"input_obj": forward(input_obj)} @parameterize("num_micro_batches", [4, 8, 12]) @@ -84,10 +84,11 @@ def examine_pp(num_micro_batches): if idx % (world_size) == local_rank: sub_model._forward = sub_model.forward sub_model.forward = MethodType( - partial(pp_linear_fwd, - stage_mgr=stage_manager, - num_chunks=NUM_CHUNKS, - model_chunk_id=len(sharded_model)), sub_model._forward) + partial( + pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model) + ), + sub_model._forward, + ) sharded_model.append(sub_model.cuda()) # create optimizer @@ -109,16 +110,13 @@ def examine_pp(num_micro_batches): torch_loss = criterion(torch_output, _) torch_loss.backward() - pp_ret = schedule.forward_backward_step(sharded_model, - iter(input_list), - criterion, - pp_optimizer, - return_loss=True, - return_outputs=True) + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + ) # check loss if stage_manager.is_last_stage(): - assert torch.allclose(torch_loss, pp_ret['loss']) + assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients torch_grad = [] @@ -147,7 +145,7 @@ def examine_pp(num_micro_batches): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") examine_pp() @@ -157,5 +155,5 @@ def test_pp(): spawn(run_dist, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_pp() diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 41b535573c39..1d77edc2db11 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -16,7 +16,6 @@ class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(4, 8) @@ -28,17 +27,15 @@ def forward(self, x): return x -def pp_linear_fwd(forward, - data: torch.Tensor = None, - input_obj: torch.Tensor = None, - stage_mgr: PipelineStageManager = None): - +def pp_linear_fwd( + forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None +): if stage_mgr.is_first_stage(): - return {'input_obj': forward(data)} + return {"input_obj": forward(data)} elif stage_mgr.is_last_stage(): return forward(input_obj) else: - return {'input_obj': forward(input_obj)} + return {"input_obj": forward(input_obj)} def examine_pp(): @@ -89,16 +86,13 @@ def examine_pp(): torch_loss = criterion(torch_output, _) torch_loss.backward() - pp_ret = schedule.forward_backward_step(sharded_model, - iter(input_list), - criterion, - pp_optimizer, - return_loss=True, - return_outputs=True) + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + ) # check loss if stage_manager.is_last_stage(): - assert torch.allclose(torch_loss, pp_ret['loss']) + assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients torch_grad = [] @@ -120,7 +114,7 @@ def examine_pp(): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") examine_pp() @@ -130,5 +124,5 @@ def test_pp(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_pp() diff --git a/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py b/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py index 4c23a23ebaba..462355ee470b 100644 --- a/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py +++ b/tests/test_pipeline/test_schedule/test_pipeline_schedule_utils.py @@ -8,9 +8,9 @@ def test_get_batch_size(): assert get_batch_size(tensor) == 2 assert get_batch_size([tensor]) == 2 assert get_batch_size((1, tensor)) == 2 - assert get_batch_size({'tensor': tensor}) == 2 - assert get_batch_size({'dummy': [1], 'tensor': tensor}) == 2 - assert get_batch_size({'tensor': [tensor]}) == 2 + assert get_batch_size({"tensor": tensor}) == 2 + assert get_batch_size({"dummy": [1], "tensor": tensor}) == 2 + assert get_batch_size({"tensor": [tensor]}) == 2 def test_get_micro_batch(): @@ -26,12 +26,12 @@ def test_get_micro_batch(): micro_batch = get_micro_batch([x, y], 1, 1) assert torch.equal(micro_batch[0], x[1:2]) assert torch.equal(micro_batch[1], y[1:2]) - micro_batch = get_micro_batch({'x': x, 'y': y}, 0, 1) - assert torch.equal(micro_batch['x'], x[0:1]) - assert torch.equal(micro_batch['y'], y[0:1]) - micro_batch = get_micro_batch({'x': x, 'y': y}, 1, 1) - assert torch.equal(micro_batch['x'], x[1:2]) - assert torch.equal(micro_batch['y'], y[1:2]) + micro_batch = get_micro_batch({"x": x, "y": y}, 0, 1) + assert torch.equal(micro_batch["x"], x[0:1]) + assert torch.equal(micro_batch["y"], y[0:1]) + micro_batch = get_micro_batch({"x": x, "y": y}, 1, 1) + assert torch.equal(micro_batch["x"], x[1:2]) + assert torch.equal(micro_batch["y"], y[1:2]) def test_merge_batch(): @@ -42,6 +42,6 @@ def test_merge_batch(): merged = merge_batch([[x[0:1], y[0:1]], [x[1:2], y[1:2]]]) assert torch.equal(merged[0], x) assert torch.equal(merged[1], y) - merged = merge_batch([{'x': x[0:1], 'y': y[0:1]}, {'x': x[1:2], 'y': y[1:2]}]) - assert torch.equal(merged['x'], x) - assert torch.equal(merged['y'], y) + merged = merge_batch([{"x": x[0:1], "y": y[0:1]}, {"x": x[1:2], "y": y[1:2]}]) + assert torch.equal(merged["x"], x) + assert torch.equal(merged["y"], y) diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py index 6e0cd1998c11..ed8284b3e64c 100644 --- a/tests/test_pipeline/test_stage_manager.py +++ b/tests/test_pipeline/test_stage_manager.py @@ -64,7 +64,7 @@ def check_stage_manager(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") check_stage_manager() @@ -74,5 +74,5 @@ def test_pipeline_stage_manager(): spawn(run_dist, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_stage_manager() diff --git a/tests/test_shardformer/test_layer/test_dist_crossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py index 72e6e5cf26ed..277a5b2bb4be 100644 --- a/tests/test_shardformer/test_layer/test_dist_crossentropy.py +++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py @@ -7,12 +7,14 @@ from colossalai.shardformer.layer import cross_entropy_1d from colossalai.testing import rerun_if_address_is_in_use, spawn -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) +CONFIG = dict( + parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")), +) def check_dist_crossentropy(rank, world_size, port, ignore_index): disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl") # prepare data pred = torch.randn(2, 4, 8, requires_grad=True) @@ -25,10 +27,11 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index): org_loss = F.cross_entropy(org_pred, org_labels) dist_pred = pred.chunk(world_size, -1)[rank] - dist_loss = cross_entropy_1d(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index) + dist_loss = cross_entropy_1d(dist_pred.to("cuda"), labels.to("cuda"), ignore_index=ignore_index) - assert torch.allclose(org_loss, dist_loss, - atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" + assert torch.allclose( + org_loss, dist_loss, atol=1e-5 + ), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" @pytest.mark.dist @@ -38,5 +41,5 @@ def test_dist_crossentropy(): spawn(check_dist_crossentropy, 2, ignore_index=ignore_index) -if __name__ == '__main__': +if __name__ == "__main__": test_dist_crossentropy() diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py index 332e377110a4..576620e6c7f3 100644 --- a/tests/test_shardformer/test_layer/test_dropout.py +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -56,7 +56,7 @@ def check_dropout_replicated_input(): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_dropout_parallel_input() check_dropout_replicated_input() @@ -66,5 +66,5 @@ def test_dropout(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_dropout() diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index d62dba7ea92a..3dbbcd766bf4 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -11,7 +11,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) +@parameterize("lazy_init", [False, True]) def check_embedding_1d(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() @@ -43,7 +43,7 @@ def check_embedding_1d(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_embedding_1d() @@ -52,5 +52,5 @@ def test_embedding_1d(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_embedding_1d() diff --git a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py index 4c0f884a7ed5..10ffdcd7138c 100644 --- a/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_gpt2_qkv_fused_linear_1d.py @@ -58,12 +58,9 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool) linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, - process_group=None, - gather_output=True, - seq_parallel=seq_parallel, - n_fused=3, - overlap=overlap) + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( + linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, n_fused=3, overlap=overlap + ) assert linear.weight.shape == torch.Size([48, 192]) assert linear.bias.shape == torch.Size([192]) @@ -97,10 +94,9 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, - process_group=None, - parallel_input=False, - seq_parallel=seq_parallel) + linear_row = GPT2FusedLinearConv1D_Row.from_native_module( + linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + ) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) @@ -128,16 +124,16 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool): assert_close(target_grad, linear_row.weight.grad) -@parameterize('lazy_init', [False, True]) -@parameterize('seq_parallel', [False, True]) -@parameterize('overlap', [True]) +@parameterize("lazy_init", [False, True]) +@parameterize("seq_parallel", [False, True]) +@parameterize("overlap", [True]) def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool): check_linear_conv_1d_col(lazy_init, seq_parallel, overlap) check_linear_conv_1d_row(lazy_init, seq_parallel) def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # test for linear conv check_gpt2_qkv_fused_linear_1d() @@ -148,5 +144,5 @@ def test_linearconv(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_linearconv() diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index f9c21b82a282..3eb3bb2e5b8d 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -10,7 +10,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) +@parameterize("lazy_init", [False, True]) def check_layernorm(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() @@ -41,7 +41,7 @@ def check_layernorm(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_layernorm() @@ -50,5 +50,5 @@ def test_layernorm(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_layernorm() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index e6d86d533ed6..5bacf1865c48 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -17,11 +17,9 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_col = Linear1D_Col.from_native_module(linear_copy, - process_group=None, - gather_output=True, - seq_parallel=seq_parallel, - overlap=overlap) + linear_col = Linear1D_Col.from_native_module( + linear_copy, process_group=None, gather_output=True, seq_parallel=seq_parallel, overlap=overlap + ) # ensure that the parameters are distributed assert is_distributed_tensor(linear_col.weight) @@ -60,8 +58,11 @@ def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool): # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( - x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + target_unshard_gard = ( + x_for_unshard.grad + if seq_parallel is False + else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + ) assert_close(target_unshard_gard, x_for_shard.grad) @@ -71,10 +72,9 @@ def check_linear_1d_row(lazy_init: bool, seq_parallel: bool): linear = nn.Linear(32, 128).cuda() with ctx: linear_copy = nn.Linear(32, 128).cuda() - linear_row = Linear1D_Row.from_native_module(linear_copy, - process_group=None, - parallel_input=False, - seq_parallel=seq_parallel) + linear_row = Linear1D_Row.from_native_module( + linear_copy, process_group=None, parallel_input=False, seq_parallel=seq_parallel + ) assert linear_row.weight.shape == torch.Size([128, 16]) assert linear_row.bias.shape == torch.Size([128]) @@ -121,15 +121,12 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool with ctx: linear_1_copy = nn.Linear(32, 128).cuda() linear_2_copy = nn.Linear(128, 32).cuda() - linear_col = Linear1D_Col.from_native_module(linear_1_copy, - process_group=None, - gather_output=False, - seq_parallel=seq_parallel, - overlap=overlap) - linear_row = Linear1D_Row.from_native_module(linear_2_copy, - process_group=None, - parallel_input=True, - seq_parallel=seq_parallel) + linear_col = Linear1D_Col.from_native_module( + linear_1_copy, process_group=None, gather_output=False, seq_parallel=seq_parallel, overlap=overlap + ) + linear_row = Linear1D_Row.from_native_module( + linear_2_copy, process_group=None, parallel_input=True, seq_parallel=seq_parallel + ) linear_1.load_state_dict(linear_col.state_dict()) linear_col.load_state_dict(linear_1.state_dict()) @@ -161,14 +158,17 @@ def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool # check the input gradients assert x_for_shard.grad is not None assert x_for_unshard.grad is not None - target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk( - x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + target_unshard_gard = ( + x_for_unshard.grad + if seq_parallel is False + else torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()] + ) assert_close(target_unshard_gard, x_for_shard.grad) -@parameterize('lazy_init', [False, True]) -@parameterize('seq_parallel', [False, True]) -@parameterize('overlap', [True]) +@parameterize("lazy_init", [False, True]) +@parameterize("seq_parallel", [False, True]) +@parameterize("overlap", [True]) def run_dist_linear_test(lazy_init, seq_parallel, overlap): check_linear_1d_col(lazy_init, seq_parallel, overlap) check_linear_1d_row(lazy_init, seq_parallel) @@ -176,7 +176,7 @@ def run_dist_linear_test(lazy_init, seq_parallel, overlap): def check_dist_linear(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_dist_linear_test() @@ -185,5 +185,5 @@ def test_linear(): spawn(check_dist_linear, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_linear() diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index b45cd172c3ca..b02d581810cd 100644 --- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -53,16 +53,15 @@ def rearrange(tensor: torch.Tensor, dim: int): return rearanged_tensor -@parameterize('lazy_init', [False, True]) +@parameterize("lazy_init", [False, True]) def check_linear_conv_1d_col(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() linear = Conv1D(192, 48).cuda() with ctx: linear_copy = Conv1D(192, 48).cuda() - linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy, - process_group=None, - gather_output=True, - n_fused=3) + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module( + linear_copy, process_group=None, gather_output=True, n_fused=3 + ) assert linear.weight.shape == torch.Size([48, 192]) assert linear.bias.shape == torch.Size([192]) @@ -89,7 +88,7 @@ def check_linear_conv_1d_col(lazy_init: bool): assert_close(target_grad, linear_conv_col.weight.grad) -@parameterize('lazy_init', [False, True]) +@parameterize("lazy_init", [False, True]) def check_linear_conv_1d_row(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() @@ -124,7 +123,7 @@ def check_linear_conv_1d_row(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # test for linear conv check_linear_conv_1d_col() @@ -136,5 +135,5 @@ def test_linearconv(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_linearconv() diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index 6d2f087302d9..b23a44f2dffa 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -11,13 +11,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -@parameterize('lazy_init', [False, True]) +@parameterize("lazy_init", [False, True]) def check_vocab_embedding_1d(lazy_init: bool): ctx = LazyInitContext() if lazy_init else nullcontext() - embedding = nn.Embedding(128, 32).to('cuda') + embedding = nn.Embedding(128, 32).to("cuda") with ctx: - embedding_copy = nn.Embedding(128, 32).to('cuda') + embedding_copy = nn.Embedding(128, 32).to("cuda") dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None) assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) @@ -30,7 +30,7 @@ def check_vocab_embedding_1d(lazy_init: bool): dist_embedding_1d.load_state_dict(embedding.state_dict()) # check embedding correctness - x = torch.randint(0, 128, (4, 32)).to('cuda') + x = torch.randint(0, 128, (4, 32)).to("cuda") org_out = embedding(x) dist_out = dist_embedding_1d(x) assert_close(org_out, dist_out) @@ -45,7 +45,7 @@ def check_vocab_embedding_1d(lazy_init: bool): def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_vocab_embedding_1d() @@ -54,5 +54,5 @@ def test_vocab_embedding(): spawn(run_dist, nprocs=2) -if __name__ == '__main__': +if __name__ == "__main__": test_vocab_embedding() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index c9c6447a43f0..0a2b151d4274 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -22,13 +22,15 @@ from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor -def build_model(model_fn, - enable_fused_normalization=True, - enable_tensor_parallelism=True, - enable_flash_attention=False, - enable_jit_fused=False, - enable_sequence_parallelism=False, - use_lazy_init: bool = False): +def build_model( + model_fn, + enable_fused_normalization=True, + enable_tensor_parallelism=True, + enable_flash_attention=False, + enable_jit_fused=False, + enable_sequence_parallelism=False, + use_lazy_init: bool = False, +): # create new model ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: @@ -38,23 +40,27 @@ def build_model(model_fn, if use_lazy_init: ctx.materialize(org_model) # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - enable_flash_attention=enable_flash_attention, - enable_jit_fused=enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism) + shard_config = ShardConfig( + enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + enable_flash_attention=enable_flash_attention, + enable_jit_fused=enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + ) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy) return org_model.cuda(), sharded_model.cuda() -def build_pipeline_model(model_fn, - stage_manager=None, - enable_fused_normalization=False, - enable_tensor_parallelism=False, - use_lazy_init: bool = False, - policy: Optional[Policy] = None): +def build_pipeline_model( + model_fn, + stage_manager=None, + enable_fused_normalization=False, + enable_tensor_parallelism=False, + use_lazy_init: bool = False, + policy: Optional[Policy] = None, +): ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: # create new model @@ -64,9 +70,11 @@ def build_pipeline_model(model_fn, ctx.materialize(org_model) # shard model - shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, - enable_tensor_parallelism=enable_tensor_parallelism, - pipeline_stage_manager=stage_manager) + shard_config = ShardConfig( + enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism, + pipeline_stage_manager=stage_manager, + ) shard_former = ShardFormer(shard_config=shard_config) sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy) @@ -91,22 +99,21 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, return org_output, org_loss, shard_output, shard_loss -def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''): +def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""): org_sd = org_model.state_dict() shard_sd = sharded_model.state_dict() for k, v in org_sd.items(): - assert k in shard_sd, f'{name} {k} not in sharded model' + assert k in shard_sd, f"{name} {k} not in sharded model" shard_v = shard_sd[k] - assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}' - assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}' - assert torch.equal(v, shard_v), f'{name} {k} value mismatch' + assert v.shape == shard_v.shape, f"{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}" + assert v.dtype == shard_v.dtype, f"{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}" + assert torch.equal(v, shard_v), f"{name} {k} value mismatch" def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any]): - use_lazy_init = False - if 'use_lazy_init' in test_config: - use_lazy_init = test_config.pop('use_lazy_init') + if "use_lazy_init" in test_config: + use_lazy_init = test_config.pop("use_lazy_init") ctx = LazyInitContext() if use_lazy_init else nullcontext() with ctx: @@ -127,9 +134,15 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster -def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Module, sharded_optimizer: Optimizer, - data_gen_fn: Callable, output_transform_fn: Callable, criterion: Callable, - booster: Booster): +def run_forward_backward_with_hybrid_plugin( + org_model: Module, + sharded_model: Module, + sharded_optimizer: Optimizer, + data_gen_fn: Callable, + output_transform_fn: Callable, + criterion: Callable, + booster: Booster, +): org_model.cuda() sharded_model.cuda() @@ -141,10 +154,10 @@ def _criterion(outputs, inputs): data = data_gen_fn() if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: - seq_len = data['input_ids'].shape[-1] + seq_len = data["input_ids"].shape[-1] lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) times = lcm // seq_len - input_shape = data['input_ids'].shape + input_shape = data["input_ids"].shape for k, v in data.items(): if v.shape == input_shape: data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,)) @@ -152,19 +165,16 @@ def _criterion(outputs, inputs): sharded_model.train() if booster.plugin.stage_manager is not None: for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = 4 - data[k] = v.to('cuda').repeat(*new_shape) + data[k] = v.to("cuda").repeat(*new_shape) data_iter = iter([data]) - sharded_output = booster.execute_pipeline(data_iter, - sharded_model, - _criterion, - sharded_optimizer, - return_loss=True, - return_outputs=True) - sharded_loss = sharded_output['loss'] + sharded_output = booster.execute_pipeline( + data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True + ) + sharded_loss = sharded_output["loss"] else: data = {k: v.cuda() for k, v in data.items()} sharded_output = sharded_model(**data) @@ -182,45 +192,49 @@ def _criterion(outputs, inputs): return org_loss, org_output, sharded_loss, sharded_output -def check_output_hidden_state(org_output: Tensor, - sharded_output: Tensor, - stage_manager: Optional[PipelineStageManager] = None, - atol: float = 1e-5, - rtol: float = 1e-3, - dim: int = 0): - +def check_output_hidden_state( + org_output: Tensor, + sharded_output: Tensor, + stage_manager: Optional[PipelineStageManager] = None, + atol: float = 1e-5, + rtol: float = 1e-3, + dim: int = 0, +): org_hidden_state = org_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(): - sharded_hidden_state = sharded_output['outputs']['last_hidden_state'] + sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] else: sharded_hidden_state = sharded_output.last_hidden_state - assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ - f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" + assert torch.allclose( + org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol + ), f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): - assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol), \ - f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" - - -def check_weight(org_model: Module, - sharded_model: Module, - layer_suffix: List[str], - tp_group: Optional[ProcessGroup] = None, - dim: int = 0, - atol: float = 1e-5, - rtol: float = 1e-3, - verbose: bool = False): - + assert torch.allclose( + org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol + ), f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" + + +def check_weight( + org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: Optional[ProcessGroup] = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, +): for suffix in layer_suffix: org_weight = getattr_(org_model, suffix).weight sharded_weight = getattr_(sharded_model, suffix).weight if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): sharded_weight_list = [ - torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group)) + torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group)) ] dist.all_gather(sharded_weight_list, sharded_weight, tp_group) sharded_weight = torch.cat(sharded_weight_list, dim=dim) @@ -228,33 +242,35 @@ def check_weight(org_model: Module, if verbose and dist.get_rank() == 0: print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") - assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \ - f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" - - -def get_grad_tensors_for_check(org_model: Module, - sharded_model: Module, - layer_suffix: List[str], - tp_group: ProcessGroup = None, - dim: int = 0, - atol: float = 1e-5, - rtol: float = 1e-3, - verbose: bool = False, - name: str = None): - + assert torch.allclose( + org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol + ), f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" + + +def get_grad_tensors_for_check( + org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: ProcessGroup = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, + name: str = None, +): grad_to_check = {} for suffix in layer_suffix: org_grad = getattr_(org_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad shard_weight = getattr_(sharded_model, suffix).weight if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] + shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) shard_grad = torch.cat(shard_grad_list, dim=dim) # embedding may be resized when using tensor parallel if shard_grad.shape[0] > org_grad.shape[0]: - shard_grad = shard_grad[:org_grad.shape[0], :] + shard_grad = shard_grad[: org_grad.shape[0], :] if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") @@ -262,33 +278,35 @@ def get_grad_tensors_for_check(org_model: Module, "org_grad": org_grad.float(), "shard_grad": shard_grad.float(), "rtol": rtol, - "atol": atol + "atol": atol, } return grad_to_check # used by sam/blip2 -def check_grad(org_model: Module, - sharded_model: Module, - layer_suffix: List[str], - tp_group: ProcessGroup = None, - dim: int = 0, - atol: float = 1e-5, - rtol: float = 1e-3, - verbose: bool = False): +def check_grad( + org_model: Module, + sharded_model: Module, + layer_suffix: List[str], + tp_group: ProcessGroup = None, + dim: int = 0, + atol: float = 1e-5, + rtol: float = 1e-3, + verbose: bool = False, +): for suffix in layer_suffix: org_grad = getattr_(org_model, suffix).weight.grad shard_grad = getattr_(sharded_model, suffix).weight.grad shard_weight = getattr_(sharded_model, suffix).weight if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): - shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))] + shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))] dist.all_gather(shard_grad_list, shard_grad, tp_group) shard_grad = torch.cat(shard_grad_list, dim=dim) # embedding may be resized when using tensor parallel if shard_grad.shape[0] > org_grad.shape[0]: - shard_grad = shard_grad[:org_grad.shape[0], :] + shard_grad = shard_grad[: org_grad.shape[0], :] if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") @@ -297,9 +315,9 @@ def check_grad(org_model: Module, ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" -def unwrap_model(module: Module, - base_model_class_name: Optional[str] = None, - base_model_attribute_name: Optional[str] = None): +def unwrap_model( + module: Module, base_model_class_name: Optional[str] = None, base_model_attribute_name: Optional[str] = None +): if isinstance(module, HybridParallelModule): module = module.unwrap() if base_model_class_name is None: diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index c779e417052b..31fd58d06f77 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -1,6 +1,5 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,52 +20,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group - bert = unwrap_model(org_model, 'BertModel', 'bert') - sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert') + bert = unwrap_model(org_model, "BertModel", "bert") + sharded_bert = unwrap_model(sharded_model, "BertModel", "bert") - col_layer_for_check = ['encoder.layer[0].output.dense'] - row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense'] + col_layer_for_check = ["encoder.layer[0].output.dense"] + row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - col_layer_grads = get_grad_tensors_for_check(bert, - sharded_bert, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - row_layer_grads = get_grad_tensors_for_check(bert, - sharded_bert, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) + col_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -76,17 +59,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'BertModel': + if org_model.__class__.__name__ == "BertModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -99,53 +82,56 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': True, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": True, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_bert_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -155,31 +141,33 @@ def run_bert_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_bert_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() @@ -189,13 +177,13 @@ def run_bert_3d_test(test_config): def check_bert(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bert_test() def check_bert_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bert_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py index cd034d0c139a..02c15460ecb3 100644 --- a/tests/test_shardformer/test_model/test_shard_blip2.py +++ b/tests/test_shardformer/test_model/test_shard_blip2.py @@ -16,16 +16,18 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + org_output, org_loss, shard_output, shard_loss = run_forward( + org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn + ) + assert_hf_output_close(org_output, shard_output, ignore_keys=["past_key_values"]) # do backward org_loss.backward() shard_loss.backward() - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose( + org_loss, shard_loss, atol=1e-5 + ), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" # check grad @@ -34,26 +36,29 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check grad col_layer_for_check = [ - 'vision_model.encoder.layers[0].self_attn.qkv', 'qformer.encoder.layer[0].attention.attention.query', - 'language_model.model.decoder.layers[0].self_attn.k_proj' + "vision_model.encoder.layers[0].self_attn.qkv", + "qformer.encoder.layer[0].attention.attention.query", + "language_model.model.decoder.layers[0].self_attn.k_proj", ] row_layer_for_check = [ - 'vision_model.encoder.layers[0].self_attn.projection', 'qformer.encoder.layer[0].attention.output.dense', - 'language_model.model.decoder.layers[0].self_attn.out_proj' + "vision_model.encoder.layers[0].self_attn.projection", + "qformer.encoder.layer[0].attention.output.dense", + "language_model.model.decoder.layers[0].self_attn.out_proj", ] check_grad(blip2, sharded_blip2, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) check_grad(blip2, sharded_blip2, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) -@parameterize('enable_jit_fused', [True, False]) +@parameterize("enable_fused_normalization", [True, False]) +@parameterize("enable_tensor_parallelism", [True, False]) +@parameterize("enable_flash_attention", [True, False]) +@parameterize("enable_jit_fused", [True, False]) def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused): - sub_model_zoo = model_zoo.get_sub_registry('transformers_blip2') + sub_model_zoo = model_zoo.get_sub_registry("transformers_blip2") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention, enable_jit_fused) + org_model, sharded_model = build_model( + model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused + ) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() @@ -61,7 +66,7 @@ def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable def check_blip2(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_blip2_test() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index c9ee690c86dc..7fe791db6d5e 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -20,53 +20,37 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - bloom = unwrap_model(org_model, 'BloomModel', 'transformer') - sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer') + bloom = unwrap_model(org_model, "BloomModel", "transformer") + sharded_bloom = unwrap_model(sharded_model, "BloomModel", "transformer") - row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings'] - col_layer_for_check = ['h[0].self_attention.dense'] + row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"] + col_layer_for_check = ["h[0].self_attention.dense"] # Save gradient tensors for comparison between the original model and the sharded model. grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-6, 1e-5 else: atol, rtol = 5e-3, 5e-3 - row_layer_grads = get_grad_tensors_for_check(bloom, - sharded_bloom, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - col_layer_grads = get_grad_tensors_for_check(bloom, - sharded_bloom, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + row_layer_grads = get_grad_tensors_for_check( + bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -76,17 +60,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'BloomModel': + if org_model.__class__.__name__ == "BloomModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -98,54 +82,51 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_bloom_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -155,29 +136,32 @@ def run_bloom_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_bloom_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -189,13 +173,13 @@ def run_bloom_3d_test(test_config): def check_bloom(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bloom_test() def check_bloom_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_bloom_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 48f651c727f4..bdf5b79fc498 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -1,6 +1,5 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,54 +20,52 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer') - shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer') + chatglm_model = unwrap_model(org_model, "ChatGLMModel", "transformer") + shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer") - row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings'] - col_layer_for_check = ['encoder.layers[0].self_attention.dense'] + row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"] + col_layer_for_check = ["encoder.layers[0].self_attention.dense"] # Save gradient tensors for comparison between the original model and the sharded model. grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-6, 1e-3 else: atol, rtol = 5e-3, 5e-3 - row_layer_grads = get_grad_tensors_for_check(chatglm_model, - shard_chatglm_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - - col_layer_grads = get_grad_tensors_for_check(chatglm_model, - shard_chatglm_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + row_layer_grads = get_grad_tensors_for_check( + chatglm_model, + shard_chatglm_model, + row_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=0, + verbose=False, + ) + + col_layer_grads = get_grad_tensors_for_check( + chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -78,30 +75,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'ChatGLMModel': + if org_model.__class__.__name__ == "ChatGLMModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(chatglm_model, - shard_chatglm_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + check_weight( + chatglm_model, + shard_chatglm_model, + col_layer_for_check, + tp_group, + atol=atol, + rtol=rtol, + dim=1, + verbose=False, + ) # check grads check_all_grad_tensors(grads_to_check) @@ -110,45 +109,41 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_chatglm_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -157,29 +152,32 @@ def run_chatglm_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_chatglm_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm') + sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -190,13 +188,13 @@ def run_chatglm_3d_test(test_config): def check_chatglm(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_chatglm_test() def check_chatglm_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_chatglm_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index c4cc3812dbfd..69a15166a54c 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -1,6 +1,5 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,53 +20,37 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer') - sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer') + gpt2 = unwrap_model(org_model, "GPT2Model", "transformer") + sharded_gpt2 = unwrap_model(sharded_model, "GPT2Model", "transformer") - col_layer_for_check = ['h[0].mlp.c_fc'] - row_layer_for_check = ['wte', 'h[0].mlp.c_proj'] + col_layer_for_check = ["h[0].mlp.c_fc"] + row_layer_for_check = ["wte", "h[0].mlp.c_proj"] # Save gradient tensors for comparison between the original model and the sharded model. grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - col_layer_grads = get_grad_tensors_for_check(gpt2, - sharded_gpt2, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - row_layer_grads = get_grad_tensors_for_check(gpt2, - sharded_gpt2, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) + col_layer_grads = get_grad_tensors_for_check( + gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + row_layer_grads = get_grad_tensors_for_check( + gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -77,19 +60,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'GPT2Model': + if org_model.__class__.__name__ == "GPT2Model": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -102,63 +85,73 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp32", + }, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) @clear_cache_before_run() def run_gpt2_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -167,30 +160,33 @@ def run_gpt2_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) @clear_cache_before_run() def run_gpt2_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -201,13 +197,13 @@ def run_gpt2_3d_test(test_config): def check_gpt2(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_gpt2_test() def check_gpt2_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_gpt2_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a60150e3cd72..f8f08e1d0075 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -2,7 +2,6 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,57 +20,41 @@ unwrap_model, ) -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - llama_model = unwrap_model(org_model, 'LlamaModel', 'model') - shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model') + llama_model = unwrap_model(org_model, "LlamaModel", "model") + shard_llama_model = unwrap_model(sharded_model, "LlamaModel", "model") - row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens'] - col_layer_for_check = ['layers[0].self_attn.o_proj'] + row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"] + col_layer_for_check = ["layers[0].self_attn.o_proj"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-6, 1e-4 else: atol, rtol = 5e-3, 5e-3 - row_layer_grads = get_grad_tensors_for_check(llama_model, - shard_llama_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - col_layer_grads = get_grad_tensors_for_check(llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + row_layer_grads = get_grad_tensors_for_check( + llama_model, shard_llama_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -81,30 +64,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'LlamaModel': + if org_model.__class__.__name__ == "LlamaModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(llama_model, - shard_llama_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + check_weight( + llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) # check grads check_all_grad_tensors(grads_to_check) @@ -112,60 +90,64 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 1, - 'pp_size': 4, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_llama_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -175,29 +157,32 @@ def run_llama_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_llama_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + sub_model_zoo = model_zoo.get_sub_registry("transformers_llama") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -209,13 +194,13 @@ def run_llama_3d_test(test_config): def check_llama(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_test() def check_llama_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_llama_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 3e74859ad1a8..d21ab264d8ab 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -2,7 +2,6 @@ import pytest import torch -from torch import distributed as dist import colossalai from colossalai.logging import disable_existing_loggers @@ -21,57 +20,41 @@ unwrap_model, ) -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - opt_model = unwrap_model(org_model, 'OPTModel', 'model') - shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model') + opt_model = unwrap_model(org_model, "OPTModel", "model") + shard_opt_model = unwrap_model(sharded_model, "OPTModel", "model") - row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens' - col_layer_for_check = ['decoder.layers[0].self_attn.out_proj'] + row_layer_for_check = ["decoder.layers[0].self_attn.q_proj", "decoder.embed_tokens"] # 'decoder.embed_tokens' + col_layer_for_check = ["decoder.layers[0].self_attn.out_proj"] # Save gradient tensors for comparison between the original model and the sharded model. grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-6, 1e-3 else: atol, rtol = 4e-2, 4e-2 - row_layer_grads = get_grad_tensors_for_check(opt_model, - shard_opt_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - col_layer_grads = get_grad_tensors_for_check(opt_model, - shard_opt_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + row_layer_grads = get_grad_tensors_for_check( + opt_model, shard_opt_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -81,29 +64,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'OPTModel': + if org_model.__class__.__name__ == "OPTModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(opt_model, - shard_opt_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + check_weight( + opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) # check grads check_all_grad_tensors(grads_to_check) @@ -112,53 +90,51 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_opt_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + sub_model_zoo = model_zoo.get_sub_registry("transformers_opt") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -166,29 +142,32 @@ def run_opt_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_opt_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + sub_model_zoo = model_zoo.get_sub_registry("transformers_opt") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -199,13 +178,13 @@ def run_opt_3d_test(test_config): def check_OPTModel(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_opt_test() def check_opt_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_opt_3d_test() @@ -223,6 +202,6 @@ def test_opt_3d(): spawn(check_opt_3d, 8) -if __name__ == '__main__': +if __name__ == "__main__": test_OPTModel() test_opt_3d() diff --git a/tests/test_shardformer/test_model/test_shard_sam.py b/tests/test_shardformer/test_model/test_shard_sam.py index 616104cd7828..a8d4cb635221 100644 --- a/tests/test_shardformer/test_model/test_shard_sam.py +++ b/tests/test_shardformer/test_model/test_shard_sam.py @@ -16,16 +16,18 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): # check forward - org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, - output_transform_fn, loss_fn) - assert_hf_output_close(org_output, shard_output, ignore_keys=['pred_masks']) + org_output, org_loss, shard_output, shard_loss = run_forward( + org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn + ) + assert_hf_output_close(org_output, shard_output, ignore_keys=["pred_masks"]) # do backward org_loss.backward() shard_loss.backward() - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose( + org_loss, shard_loss, atol=1e-5 + ), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" # check grad @@ -33,20 +35,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo sharded_sam = sharded_model # check grad - col_layer_for_check = ['mask_decoder.transformer.layers[0].self_attn.q_proj', 'vision_encoder.layers[0].mlp.lin1'] - row_layer_for_check = ['mask_decoder.transformer.layers[0].self_attn.out_proj', 'vision_encoder.layers[0].mlp.lin2'] + col_layer_for_check = ["mask_decoder.transformer.layers[0].self_attn.q_proj", "vision_encoder.layers[0].mlp.lin1"] + row_layer_for_check = ["mask_decoder.transformer.layers[0].self_attn.out_proj", "vision_encoder.layers[0].mlp.lin2"] check_grad(sam, sharded_sam, col_layer_for_check, atol=1e-5, rtol=1e-3, dim=0, verbose=False) check_grad(sam, sharded_sam, row_layer_for_check, atol=1e-3, rtol=1e-3, dim=1, verbose=False) -@parameterize('enable_fused_normalization', [True, False]) -@parameterize('enable_tensor_parallelism', [True, False]) -@parameterize('enable_flash_attention', [True, False]) +@parameterize("enable_fused_normalization", [True, False]) +@parameterize("enable_tensor_parallelism", [True, False]) +@parameterize("enable_flash_attention", [True, False]) def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention): - sub_model_zoo = model_zoo.get_sub_registry('transformers_sam') + sub_model_zoo = model_zoo.get_sub_registry("transformers_sam") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, - enable_flash_attention) + org_model, sharded_model = build_model( + model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention + ) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() @@ -54,7 +57,7 @@ def run_sam_test(enable_fused_normalization, enable_tensor_parallelism, enable_f def check_sam(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_sam_test() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 768cae0a6734..73f203d1f023 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,6 +1,5 @@ import pytest import torch -from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.logging import disable_existing_loggers @@ -21,19 +20,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group @@ -42,22 +35,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, t5 = unwrap_model(org_model) sharded_t5 = unwrap_model(sharded_model) - row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q'] + row_layer_for_check = ["shared", "encoder.block[0].layer[0].SelfAttention.q"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - row_layer_grads = get_grad_tensors_for_check(t5, - sharded_t5, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0) + row_layer_grads = get_grad_tensors_for_check( + t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0 + ) grads_to_check.update(row_layer_grads) # optimizer executes step @@ -66,18 +55,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ != 'T5ForConditionalGeneration': + if org_model.__class__.__name__ != "T5ForConditionalGeneration": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 5e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -90,67 +79,70 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 1, - 'pp_size': 4, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", + }, + { + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) @clear_cache_before_run() def run_t5_test(test_config): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - # skip 4-stage pp test for t5_encoder - if test_config['pp_size'] > 2 and name == 'transformers_t5_encoder_model': + if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model": continue check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -160,29 +152,32 @@ def run_t5_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp16', - 'zero_stage': 1, - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, + ], +) def run_t5_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + sub_model_zoo = model_zoo.get_sub_registry("transformers_t5") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -193,13 +188,13 @@ def run_t5_3d_test(test_config): def check_t5(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_t5_test() def check_t5_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_t5_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 15db63bfd9da..1c934bd22340 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -20,54 +20,38 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwrap model - vit_model = unwrap_model(org_model, 'ViTModel', 'vit') - shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit') + vit_model = unwrap_model(org_model, "ViTModel", "vit") + shard_vit_model = unwrap_model(sharded_model, "ViTModel", "vit") # check grad - row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection'] - col_layer_for_check = ['encoder.layer[0].attention.output.dense'] + row_layer_for_check = ["encoder.layer[0].attention.attention.query", "embeddings.patch_embeddings.projection"] + col_layer_for_check = ["encoder.layer[0].attention.output.dense"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - row_layer_grads = get_grad_tensors_for_check(vit_model, - shard_vit_model, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) - col_layer_grads = get_grad_tensors_for_check(vit_model, - shard_vit_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + row_layer_grads = get_grad_tensors_for_check( + vit_model, shard_vit_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) + col_layer_grads = get_grad_tensors_for_check( + vit_model, shard_vit_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -77,29 +61,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'ViTModel': + if org_model.__class__.__name__ == "ViTModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights if stage_manager is None or stage_manager.is_first_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - check_weight(vit_model, - shard_vit_model, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) + check_weight( + vit_model, shard_vit_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) # check grads check_all_grad_tensors(grads_to_check) @@ -107,57 +86,54 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -#TODO: num_microbatch size = 2 inf loss -@parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp16', - 'initial_scale': 1, -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32' -}, { - 'tp_size': 2, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'zero_stage': 2, - 'precision': 'fp16', - 'initial_scale': 1 -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'zero_stage': 1, - 'precision': 'fp16', - 'initial_scale': 1 -}]) +# TODO: num_microbatch size = 2 inf loss +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + }, + {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, + { + "tp_size": 2, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "zero_stage": 2, + "precision": "fp16", + "initial_scale": 1, + }, + { + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": True, + "use_lazy_init": False, + "zero_stage": 1, + "precision": "fp16", + "initial_scale": 1, + }, + ], +) def run_vit_test(test_config): - # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models - sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + sub_model_zoo = model_zoo.get_sub_registry("transformers_vit") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -166,28 +142,31 @@ def run_vit_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + ], +) def run_vit_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + sub_model_zoo = model_zoo.get_sub_registry("transformers_vit") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -198,13 +177,13 @@ def run_vit_3d_test(test_config): def check_vit(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_vit_test() def check_vit_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_vit_3d_test() diff --git a/tests/test_shardformer/test_model/test_shard_whisper.py b/tests/test_shardformer/test_model/test_shard_whisper.py index d0c04c98f80a..f839bd84ab69 100644 --- a/tests/test_shardformer/test_model/test_shard_whisper.py +++ b/tests/test_shardformer/test_model/test_shard_whisper.py @@ -5,13 +5,7 @@ from colossalai.logging import disable_existing_loggers from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter -from colossalai.testing import ( - assert_hf_output_close, - clear_cache_before_run, - parameterize, - rerun_if_address_is_in_use, - spawn, -) +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, @@ -26,24 +20,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): # check forward - org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ - build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) - - org_loss, org_output, sharded_loss, sharded_output = \ - run_forward_backward_with_hybrid_plugin( - org_model, - sharded_model, - sharded_optimizer, - data_gen_fn, - output_transform_fn, - criterion, - booster) + org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( + model_fn, loss_fn, test_config + ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group # unwarp the model - if org_model.__class__.__name__ == 'WhisperForConditionalGeneration': + if org_model.__class__.__name__ == "WhisperForConditionalGeneration": whisper = org_model.model sharded_whisper = sharded_model.unwrap().model else: @@ -51,41 +40,33 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_whisper = sharded_model.unwrap() # check grad - if org_model.__class__.__name__ == 'WhisperForAudioClassification': - col_layer_for_check = ['encoder.layers[0].self_attn.q_proj'] - row_layer_for_check = ['encoder.layers[0].self_attn.out_proj'] + if org_model.__class__.__name__ == "WhisperForAudioClassification": + col_layer_for_check = ["encoder.layers[0].self_attn.q_proj"] + row_layer_for_check = ["encoder.layers[0].self_attn.out_proj"] else: col_layer_for_check = [ - 'encoder.layers[0].self_attn.q_proj', - # 'decoder.layers[0].self_attn.q_proj' + "encoder.layers[0].self_attn.q_proj", + # 'decoder.layers[0].self_attn.q_proj' ] row_layer_for_check = [ - 'encoder.layers[0].self_attn.out_proj', - #'decoder.layers[0].self_attn.out_proj' + "encoder.layers[0].self_attn.out_proj", + #'decoder.layers[0].self_attn.out_proj' ] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. grads_to_check = {} - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - row_layer_grads = get_grad_tensors_for_check(whisper, - sharded_whisper, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1) - col_layer_grads = get_grad_tensors_for_check(whisper, - sharded_whisper, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0) + row_layer_grads = get_grad_tensors_for_check( + whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1 + ) + col_layer_grads = get_grad_tensors_for_check( + whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0 + ) grads_to_check.update(col_layer_grads) grads_to_check.update(row_layer_grads) @@ -95,38 +76,28 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 2e-4, 2e-4 else: atol, rtol = 5e-3, 5e-3 - if org_model.__class__.__name__ == 'WhisperModel': + if org_model.__class__.__name__ == "WhisperModel": check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights - if test_config['precision'] == 'fp32': + if test_config["precision"] == "fp32": atol, rtol = 1e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 if stage_manager is None or stage_manager.is_first_stage(): - check_weight(whisper, - sharded_whisper, - row_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=1, - verbose=False) - check_weight(whisper, - sharded_whisper, - col_layer_for_check, - tp_group, - atol=atol, - rtol=rtol, - dim=0, - verbose=False) + check_weight( + whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False + ) + check_weight( + whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False + ) # check grads check_all_grad_tensors(grads_to_check) @@ -134,49 +105,49 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, torch.cuda.empty_cache() -#TODO fix WhisperForConditionalGeneration enable jit fused operato +# TODO fix WhisperForConditionalGeneration enable jit fused operato # TODO(jianghai) fix fp16 @parameterize( - 'test_config', + "test_config", [ { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': True, - 'use_lazy_init': True, - 'precision': 'fp32', - 'initial_scale': 1, + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": True, + "use_lazy_init": True, + "precision": "fp32", + "initial_scale": 1, }, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, + "tp_size": 1, + "pp_size": 2, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, }, { - 'tp_size': 4, - 'pp_size': 1, - 'enable_all_optimization': True, - 'use_lazy_init': False, - 'precision': 'fp32', + "tp_size": 4, + "pp_size": 1, + "enable_all_optimization": True, + "use_lazy_init": False, + "precision": "fp32", }, { - 'tp_size': 1, - 'pp_size': 4, - 'num_microbatches': 4, - 'use_lazy_init': False, - 'precision': 'fp32', + "tp_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "use_lazy_init": False, + "precision": "fp32", }, - # whisper is not supported fp16 for now. - ]) + # whisper is not supported fp16 for now. + ], +) def run_whisper_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') + sub_model_zoo = model_zoo.get_sub_registry("transformers_whisper") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - - if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification': + if test_config["pp_size"] > 2 and name == "transformers_whisper_for_audio_classification": continue check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -185,28 +156,31 @@ def run_whisper_test(test_config): torch.cuda.empty_cache() -@parameterize('test_config', [ - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, - { - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 2, - 'enable_all_optimization': False, - 'use_lazy_init': False, - 'precision': 'fp32', - 'initial_scale': 1, - }, -]) +@parameterize( + "test_config", + [ + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + { + "tp_size": 2, + "pp_size": 2, + "num_microbatches": 2, + "enable_all_optimization": False, + "use_lazy_init": False, + "precision": "fp32", + "initial_scale": 1, + }, + ], +) def run_whisper_3d_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper') + sub_model_zoo = model_zoo.get_sub_registry("transformers_whisper") for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) @@ -217,13 +191,13 @@ def run_whisper_3d_test(test_config): def check_whisper(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_whisper_test() def check_whisper_3d(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_whisper_3d_test() diff --git a/tests/test_shardformer/test_shard_utils.py b/tests/test_shardformer/test_shard_utils.py index 220b8291c9c6..9739fad86d39 100644 --- a/tests/test_shardformer/test_shard_utils.py +++ b/tests/test_shardformer/test_shard_utils.py @@ -5,7 +5,6 @@ class Net(nn.Module): - def __init__(self) -> None: super().__init__() self.layers = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index 2b6933246298..f642a9dcada4 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -14,10 +14,9 @@ from tests.kit.model_zoo import model_zoo -@parameterize('lazy_init', [True, False]) +@parameterize("lazy_init", [True, False]) def check_shardformer_with_ddp(lazy_init: bool): - - sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") # create shardformer # ranks: [0, 1, 2, 3] @@ -72,7 +71,7 @@ def check_shardformer_with_ddp(lazy_init: bool): def run_dist(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") check_shardformer_with_ddp() diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py index 4a3199c1c53d..5e969b1aaf98 100644 --- a/tests/test_tensor/test_comm_spec_apply.py +++ b/tests/test_tensor/test_comm_spec_apply.py @@ -29,10 +29,9 @@ def check_all_gather(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1 + ) sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm) assert sharded_tensor_to_comm.equal(tensor_to_check) @@ -101,11 +100,9 @@ def check_all_to_all(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, - sharding_spec, - gather_dim=0, - shard_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, sharding_spec, gather_dim=0, shard_dim=1, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -181,7 +178,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) assert rank == dist.get_rank() @@ -214,5 +211,5 @@ def test_comm_spec(): spawn(check_comm, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_comm_spec() diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index a1ea2946e6e7..6d1640b4f3dc 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -20,10 +20,9 @@ def check_all_gather(process_groups_dict, rank): tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda() # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - process_groups_dict, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, process_groups_dict, gather_dim=1, logical_process_axis=1 + ) sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm) assert sharded_tensor_to_comm.equal(tensor_to_check) @@ -38,10 +37,9 @@ def check_shard(process_groups_dict, rank): tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1) # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, - process_groups_dict, - shard_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, process_groups_dict, shard_dim=1, logical_process_axis=1 + ) tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard) if rank in (0, 2): @@ -79,11 +77,13 @@ def check_all_to_all(process_groups_dict, rank): tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda() # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, - process_groups_dict, - gather_dim=0, - shard_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD, + process_groups_dict, + gather_dim=0, + shard_dim=1, + logical_process_axis=0, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -124,7 +124,7 @@ def check_all_reduce_bwd(process_groups_dict, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) assert rank == dist.get_rank() @@ -157,5 +157,5 @@ def test_comm_spec(): spawn(check_comm, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_comm_spec() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 5a1aef79f332..33ae59d01550 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -8,7 +8,6 @@ class TestModel(torch.nn.Module): - def __init__(self, in_features, out_features): super().__init__() self.linear_1 = torch.nn.Linear(in_features, out_features) @@ -22,9 +21,9 @@ def forward(self, x): def check_dtensor(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - test_model = TestModel(8, 8).to('cuda') - original_tensor = torch.rand(4, 8).to('cuda') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + test_model = TestModel(8, 8).to("cuda") + original_tensor = torch.rand(4, 8).to("cuda") compare_output = test_model(original_tensor) device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) @@ -39,7 +38,7 @@ def check_dtensor(rank, world_size, port): elif rank in (2, 3): assert d_tensor.equal(original_tensor.narrow(0, 2, 2)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") assert to_global(d_tensor).equal(original_tensor) output = test_model(d_tensor) @@ -48,7 +47,7 @@ def check_dtensor(rank, world_size, port): elif rank in (2, 3): assert output.equal(compare_output.narrow(0, 2, 2)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]}) d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec) @@ -62,7 +61,7 @@ def check_dtensor(rank, world_size, port): elif rank == 3: assert d_tensor.equal(original_tensor.narrow(0, 3, 1)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec) @@ -75,7 +74,7 @@ def check_dtensor(rank, world_size, port): elif rank == 3: assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1)) else: - raise ValueError(f'rank {rank} is not in the device mesh') + raise ValueError(f"rank {rank} is not in the device mesh") @rerun_if_address_is_in_use() @@ -84,5 +83,5 @@ def test_dtensor(): spawn(check_dtensor, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_dtensor() diff --git a/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py b/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py index 7fd1c3d90fc4..654a4438479a 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py +++ b/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py @@ -26,9 +26,10 @@ def test_dtensor_sharding_spec(): assert dim_spec_list_0[2].dim_diff(dim_spec_list_1[2]) == 0 assert dim_spec_list_0[3].dim_diff(dim_spec_list_1[3]) == 0 - assert sharding_spec_0.spec_diff(sharding_spec_1) == \ - reduce(operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0) + assert sharding_spec_0.spec_diff(sharding_spec_1) == reduce( + operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0 + ) -if __name__ == '__main__': +if __name__ == "__main__": test_dtensor_sharding_spec() diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 5388fd901e09..4e65401bf7b4 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -20,7 +20,7 @@ def check_one_step_transform(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # [[0, 1], # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -34,10 +34,10 @@ def check_one_step_transform(rank, world_size, port): rst_dict = layout_converter.all_gather_transform_layouts(layout) - assert '[R, S1, R]' in [ + assert "[R, S1, R]" in [ str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys() ] - assert '[S0, R, R]' in [ + assert "[S0, R, R]" in [ str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys() ] @@ -50,13 +50,13 @@ def check_one_step_transform(rank, world_size, port): rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() ] - assert '[R, S1, S0]' in [ + assert "[R, S1, S0]" in [ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() ] @@ -69,20 +69,20 @@ def check_one_step_transform(rank, world_size, port): rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() ] - assert '[S0, S1, R]' in [ + assert "[S0, S1, R]" in [ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() ] def check_layout_converting(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) @@ -102,8 +102,8 @@ def check_layout_converting(rank, world_size, port): transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) # check transform path - transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) - assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + transform_path_str = "->".join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) + assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]" # check comm action sequence # all-gather(S01) -> S0 @@ -123,18 +123,18 @@ def check_layout_converting(rank, world_size, port): assert comm_action_sequence[2].logical_process_axis == 1 # checkout chached_spec_pairs_transform_path - assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path - assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence + assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][0] == transform_path + assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][1] == comm_action_sequence comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout) - assert comm_cost['forward'] == comm_cost['backward'] - assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward']) + assert comm_cost["forward"] == comm_cost["backward"] + assert math.floor(comm_cost["total"]) == math.floor(comm_cost["forward"] + comm_cost["backward"]) def check_layout_converting_apply(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -173,5 +173,5 @@ def test_layout_converter(): spawn(check_layout_converting_apply, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_layout_converter() diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py index bd71bffccc70..7d6f8979dd0b 100644 --- a/tests/test_tensor/test_mix_gather.py +++ b/tests/test_tensor/test_mix_gather.py @@ -17,12 +17,13 @@ def check_mix_gather_S0S1(device_mesh, rank): f_target_pair = (f, [0]) b_target_pair = (b, [1]) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - tensor_slice = [4, 2] # (4, 2) + tensor_slice = [4, 2] # (4, 2) rank_slice = 4 f_start = (rank // rank_slice) * tensor_slice[0] b_start = (rank % rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) dim_partition_dict = {0: [0], 1: [1]} @@ -31,12 +32,14 @@ def check_mix_gather_S0S1(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -48,12 +51,13 @@ def check_two_all_gather_S0S1(device_mesh, rank): dim_partition_dict = {0: [0], 1: [1]} - tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2) + tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2) rank_slice = 4 f_start = (rank // rank_slice) * tensor_slice[0] b_start = (rank % rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) # DistSpec: # shard_sequence: S0,S1 @@ -61,10 +65,9 @@ def check_two_all_gather_S0S1(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -75,10 +78,9 @@ def check_two_all_gather_S0S1(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -95,8 +97,9 @@ def check_mix_gather_S1S0(device_mesh, rank): rank_slice = 4 f_start = (rank % rank_slice) * tensor_slice[0] b_start = (rank // rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) dim_partition_dict = {0: [1], 1: [0]} @@ -105,12 +108,14 @@ def check_mix_gather_S1S0(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -120,12 +125,13 @@ def check_two_all_gather_S1S0(device_mesh, rank): tensor_width = 8 tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() - tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2) + tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2) rank_slice = 4 f_start = (rank % rank_slice) * tensor_slice[0] b_start = (rank // rank_slice) * tensor_slice[1] - tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], - b_start:b_start + tensor_slice[1]].contiguous().cuda() + tensor_to_comm = ( + tensor_to_check[f_start : f_start + tensor_slice[0], b_start : b_start + tensor_slice[1]].contiguous().cuda() + ) dim_partition_dict = {0: [1], 1: [0]} @@ -135,10 +141,9 @@ def check_two_all_gather_S1S0(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -149,10 +154,9 @@ def check_two_all_gather_S1S0(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -165,7 +169,7 @@ def check_mix_gather_S01R(device_mesh, rank): f_target_pair = (f, [0, 1]) b_target_pair = (b, []) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - tensor_to_comm = tensor_to_check[rank:rank + 1, :].contiguous().cuda() + tensor_to_comm = tensor_to_check[rank : rank + 1, :].contiguous().cuda() dim_partition_dict = {0: [0, 1]} # DistSpec: @@ -173,12 +177,14 @@ def check_mix_gather_S01R(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -189,7 +195,7 @@ def check_two_all_gather_S01R(device_mesh, rank): tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() rank_stride = tensor_width // 8 - tensor_to_comm = tensor_to_check[rank:rank + rank_stride, :].contiguous().cuda() + tensor_to_comm = tensor_to_check[rank : rank + rank_stride, :].contiguous().cuda() dim_partition_dict = {0: [0, 1]} @@ -199,10 +205,9 @@ def check_two_all_gather_S01R(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -214,10 +219,9 @@ def check_two_all_gather_S01R(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=0, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=0, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -231,7 +235,7 @@ def check_mix_gather_RS01(device_mesh, rank): f_target_pair = (f, []) b_target_pair = (b, [0, 1]) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) - tensor_to_comm = tensor_to_check[:, rank:rank + 1].contiguous().cuda() + tensor_to_comm = tensor_to_check[:, rank : rank + 1].contiguous().cuda() dim_partition_dict = {1: [0, 1]} # DistSpec: @@ -239,12 +243,14 @@ def check_mix_gather_RS01(device_mesh, rank): # device_mesh_shape: (2, 4) source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) - comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, - sharding_spec=source_spec, - gather_dim=gather_dim, - logical_process_axis=logical_process_axes, - forward_only=True, - mix_gather=True) + comm_spec = CommSpec( + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True, + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) assert tensor_to_comm.equal(tensor_to_check) @@ -255,7 +261,7 @@ def check_two_all_gather_RS01(device_mesh, rank): tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() rank_stride = tensor_width // 8 - tensor_to_comm = tensor_to_check[:, rank:rank + rank_stride].contiguous().cuda() + tensor_to_comm = tensor_to_check[:, rank : rank + rank_stride].contiguous().cuda() dim_partition_dict = {1: [0, 1]} @@ -265,10 +271,9 @@ def check_two_all_gather_RS01(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=1) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=1 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -280,10 +285,9 @@ def check_two_all_gather_RS01(device_mesh, rank): sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) - comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, - sharding_spec, - gather_dim=1, - logical_process_axis=0) + comm_spec = CommSpec( + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, sharding_spec, gather_dim=1, logical_process_axis=0 + ) tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) @@ -292,7 +296,7 @@ def check_two_all_gather_RS01(device_mesh, rank): def check_comm(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 8) assert rank == dist.get_rank() @@ -326,5 +330,5 @@ def test_mix_gather(): spawn(check_comm, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_mix_gather() diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 859eef051256..c51797912e6f 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -2,7 +2,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +from colossalai.tensor.sharding_spec import ShardingSpec physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) @@ -16,7 +16,6 @@ def test_one_step_transform(): - dim_partition_dict = {0: [0], 1: [1]} # DistSpec: # shard_sequence: S0,S1,R @@ -28,16 +27,14 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec: # shard_sequence: S0,R,R # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)} - rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, { - "forward": 0, - "backward": 0, - "total": 0 - }) + rst_dict = shape_consistency_manager.get_all_all_gather_spec( + sharding_spec, {"forward": 0, "backward": 0, "total": 0} + ) - assert '[R, S1, R]' in [ + assert "[R, S1, R]" in [ str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() ] - assert '[S0, R, R]' in [ + assert "[S0, R, R]" in [ str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() ] @@ -53,19 +50,17 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec: # shard_sequence: S0,R,S1 # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)} - rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, { - "forward": 0, - "backward": 0, - "total": 0 - }) + rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec( + sharding_spec_all2all, {"forward": 0, "backward": 0, "total": 0} + ) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() ] - assert '[R, S1, S0]' in [ + assert "[R, S1, S0]" in [ str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() ] @@ -81,19 +76,17 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec: # shard_sequence: S0,R,S1 # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)} - rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, { - "forward": 0, - "backward": 0, - "total": 0 - }) + rst_dict_shard = shape_consistency_manager.get_all_shard_spec( + sharding_spec_shard, {"forward": 0, "backward": 0, "total": 0} + ) - assert '[S01, R, R]' in [ + assert "[S01, R, R]" in [ str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() ] - assert '[S0, S1, R]' in [ + assert "[S0, S1, R]" in [ str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() ] - assert '[S0, R, S1]' in [ + assert "[S0, R, S1]" in [ str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys() ] @@ -113,10 +106,11 @@ def test_shape_consistency(): sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( - sharding_spec_source, sharding_spec_target) + sharding_spec_source, sharding_spec_target + ) - transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path]) - assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + transform_path_str = "->".join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path]) + assert transform_path_str == "[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]" # all-gather(S01) -> S0 assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD @@ -134,12 +128,15 @@ def test_shape_consistency(): assert comm_action_sequence[2].shard_dim == 0 assert comm_action_sequence[2].logical_process_axis == 1 - assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', - '[S01, R, R]')][0] == transform_path - assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', - '[S01, R, R]')][1] == comm_action_sequence + assert ( + shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")][0] == transform_path + ) + assert ( + shape_consistency_manager.cached_spec_pairs_transform_path[("[R, S01, R]", "[S01, R, R]")][1] + == comm_action_sequence + ) -if __name__ == '__main__': +if __name__ == "__main__": test_one_step_transform() test_shape_consistency() diff --git a/tests/test_tensor/test_shape_consistency_apply.py b/tests/test_tensor/test_shape_consistency_apply.py index b57952df401f..b2bc84edd87f 100644 --- a/tests/test_tensor/test_shape_consistency_apply.py +++ b/tests/test_tensor/test_shape_consistency_apply.py @@ -4,14 +4,14 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn def check_apply(rank, world_size, port): disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -72,5 +72,5 @@ def test_apply(): spawn(check_apply, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_apply() diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py index 5007c4141849..7730683bf525 100644 --- a/tests/test_tensor/test_sharding_spec.py +++ b/tests/test_tensor/test_sharding_spec.py @@ -1,7 +1,7 @@ import torch from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +from colossalai.tensor.sharding_spec import ShardingSpec def test_sharding_spec(): @@ -21,5 +21,5 @@ def test_sharding_spec(): assert str(sharding_spec.sharding_sequence) == "[S01, R, R]" -if __name__ == '__main__': +if __name__ == "__main__": test_sharding_spec() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index f775710c40c2..a5c465ba0b07 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -23,30 +23,30 @@ def attention_ref(q, k, v, attn_mask=None, causal=False): seqlen_q, seqlen_k = q.shape[1], k.shape[1] d = q.shape[-1] scale = 1.0 / math.sqrt(d) - scores = torch.einsum('bthd,bshd->bhts', q * scale, k) + scores = torch.einsum("bthd,bshd->bhts", q * scale, k) if attn_mask is not None: - scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf')) + scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf")) if causal: causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) - scores.masked_fill_(causal_mask, float('-inf')) + scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) - output = torch.einsum('bhts,bshd->bthd', attention, v) + output = torch.einsum("bhts,bshd->bthd", attention, v) output = rearrange(output, "b s h d -> b s (h d)") # Modify the data at the positions of the mask to 0 if attn_mask is not None: - output.masked_fill_(rearrange(~attn_mask, 'b s -> b s 1'), 0.0) + output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0) return output.to(dtype=dtype_og) @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('proj_shape', [(6, 8, 4, 16)]) -@parameterize('dtype', DTYPE) -@parameterize('dropout', [0.0]) +@parameterize("proj_shape", [(6, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) def test_attention_gpt(proj_shape, dtype, dropout): (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD @@ -78,9 +78,9 @@ def test_attention_gpt(proj_shape, dtype, dropout): @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('proj_shape', [(6, 8, 4, 16)]) -@parameterize('dtype', DTYPE) -@parameterize('dropout', [0.0]) +@parameterize("proj_shape", [(6, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) def test_attention_bert(proj_shape, dtype, dropout): (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD @@ -111,9 +111,9 @@ def test_attention_bert(proj_shape, dtype, dropout): @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('proj_shape', [(6, 8, 4, 16)]) -@parameterize('dtype', DTYPE) -@parameterize('dropout', [0.0]) +@parameterize("proj_shape", [(6, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) def test_attention_no_mask(proj_shape, dtype, dropout): (B, S, H, D_HEAD) = proj_shape D = H * D_HEAD @@ -141,9 +141,9 @@ def test_attention_no_mask(proj_shape, dtype, dropout): @pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available") @clear_cache_before_run() -@parameterize('proj_shape', [(6, 24, 8, 4, 16)]) -@parameterize('dtype', DTYPE) -@parameterize('dropout', [0.0]) +@parameterize("proj_shape", [(6, 24, 8, 4, 16)]) +@parameterize("dtype", DTYPE) +@parameterize("dropout", [0.0]) def test_cross_attention(proj_shape, dtype, dropout): (B, S, T, H, D_HEAD) = proj_shape D = H * D_HEAD diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py index f05ccfdbd41b..879eeccde3b4 100644 --- a/tests/test_zero/test_gemini/test_chunk_mgrv2.py +++ b/tests/test_zero/test_gemini/test_chunk_mgrv2.py @@ -12,54 +12,53 @@ CPU_MEM = {True: {True: 0, False: 0}, False: {True: 512, False: 0}} -@parameterize('keep_gathered', [True, False]) -@parameterize('pin_memory', [True, False]) +@parameterize("keep_gathered", [True, False]) +@parameterize("pin_memory", [True, False]) def exam_chunk_memory(keep_gathered, pin_memory): - params = [ColoTensor(torch.rand(8, 8)) for _ in range(3)] config = {2: dict(chunk_size=128, keep_gathered=keep_gathered)} chunk_manager = ChunkManager(config) - assert chunk_manager.total_mem['cpu'] == 0 - assert chunk_manager.total_mem['cuda'] == 0 + assert chunk_manager.total_mem["cpu"] == 0 + assert chunk_manager.total_mem["cuda"] == 0 process_group = _get_default_group() for p in params: - chunk_manager.register_tensor(p, 'param', 2, process_group, pin_memory=pin_memory) + chunk_manager.register_tensor(p, "param", 2, process_group, pin_memory=pin_memory) chunk_manager.close_all_groups() - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_0[keep_gathered] chunks = chunk_manager.get_chunks(params) for chunk in chunks: chunk_manager.access_chunk(chunk) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[True] + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_0[True] for chunk in chunks: chunk_manager.release_chunk(chunk) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][pin_memory] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_0[keep_gathered] + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][pin_memory] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_0[keep_gathered] for chunk in chunks: - chunk_manager.move_chunk(chunk, torch.device('cpu')) - assert chunk_manager.total_mem['cpu'] == CPU_MEM[keep_gathered][True] - assert chunk_manager.total_mem['cuda'] == CUDA_MEM_1[keep_gathered] + chunk_manager.move_chunk(chunk, torch.device("cpu")) + assert chunk_manager.total_mem["cpu"] == CPU_MEM[keep_gathered][True] + assert chunk_manager.total_mem["cuda"] == CUDA_MEM_1[keep_gathered] def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_chunk_memory() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_chunk_manager(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_chunk_manager(2) diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index cc598ee60361..a31c888e966d 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -31,26 +31,28 @@ def check_equal(param, param_cp): return torch.equal(temp, param_cp.data) -@parameterize('init_device', [None, torch.device('cpu')]) -@parameterize('keep_gathered', [True, False]) -@parameterize('pin_memory', [True, False]) +@parameterize("init_device", [None, torch.device("cpu")]) +@parameterize("keep_gathered", [True, False]) +@parameterize("pin_memory", [True, False]) def exam_chunk_basic(init_device, keep_gathered, pin_memory): world_size = torch.distributed.get_world_size() pg = _get_default_group() - my_chunk = Chunk(chunk_size=1024, - process_group=pg, - dtype=torch.float32, - init_device=init_device, - cpu_shard_init=True, - keep_gathered=keep_gathered, - pin_memory=pin_memory) + my_chunk = Chunk( + chunk_size=1024, + process_group=pg, + dtype=torch.float32, + init_device=init_device, + cpu_shard_init=True, + keep_gathered=keep_gathered, + pin_memory=pin_memory, + ) param_list = [] param_cp_list = [] - add_param(param_list, param_cp_list, 8, 8, 8, device='cuda') + add_param(param_list, param_cp_list, 8, 8, 8, device="cuda") add_param(param_list, param_cp_list, 4, 4) - add_param(param_list, param_cp_list, 4, 8, 2, device='cuda') + add_param(param_list, param_cp_list, 4, 8, 2, device="cuda") add_param(param_list, param_cp_list, 1, 1, 5) for param in param_list: @@ -62,12 +64,12 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): if keep_gathered is False: assert my_chunk.cpu_shard.size(0) == 1024 // world_size - assert my_chunk.device_type == 'cpu' + assert my_chunk.device_type == "cpu" assert my_chunk.can_move my_chunk.shard_move(get_current_device()) else: assert my_chunk.cuda_global_chunk.size(0) == 1024 - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" assert not my_chunk.can_move assert dist_sum(my_chunk.valid_end) == my_chunk.utilized_size @@ -75,7 +77,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): assert not flag, "has_inf_or_nan is {}".format(flag) my_chunk.access_chunk() - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" for param, param_cp in zip(param_list, param_cp_list): check_equal(param, param_cp) @@ -97,25 +99,25 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory): if keep_gathered is False: assert my_chunk.cuda_shard.size(0) == 1024 // world_size - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" assert my_chunk.can_move else: assert my_chunk.cuda_global_chunk.size(0) == 1024 - assert my_chunk.device_type == 'cuda' + assert my_chunk.device_type == "cuda" assert not my_chunk.can_move def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_chunk_basic() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2, 4]) +@pytest.mark.parametrize("world_size", [1, 2, 4]) @rerun_if_address_is_in_use() def test_chunk_function(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_chunk_function(4) diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index fabdd6072c31..94e70040019c 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -16,21 +16,10 @@ from tests.components_to_test.registry import non_distributed_component_funcs PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0 - }, # zero3 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.5 - }, # zero3-half - { - 'placement_policy': 'auto' - } + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, ] @@ -41,14 +30,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): for chunk in chunk_list: chunk_manager.access_chunk(chunk) - for (p0, p1) in zip(model.parameters(), torch_model.parameters()): + for p0, p1 in zip(model.parameters(), torch_model.parameters()): assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('keep_gather', [False, True]) -@parameterize('model_name', ['gpt2', 'bert', 'albert']) -@parameterize('use_grad_checkpoint', [False, True]) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gather", [False, True]) +@parameterize("model_name", ["gpt2", "bert", "albert"]) +@parameterize("use_grad_checkpoint", [False, True]) def exam_gpt_fwd_bwd( placement_config, keep_gather, @@ -69,14 +58,14 @@ def exam_gpt_fwd_bwd( world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gather + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gather model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1) rank = dist.get_rank() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[rank]) @@ -105,16 +94,16 @@ def exam_gpt_fwd_bwd( def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_gpt_fwd_bwd() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_gpt(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_gpt(4) diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py index 614a96ccdbcd..2fa2d50a6caa 100644 --- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -14,10 +14,10 @@ # run gemini use the runtime memory tracer -@parameterize('placement_policy', ['auto']) -@parameterize('keep_gather', [False]) -@parameterize('model_name', ['repeated_computed_layers', 'bert', 'albert', 'gpt2']) -@parameterize('use_grad_checkpoint', [False, True]) +@parameterize("placement_policy", ["auto"]) +@parameterize("keep_gather", [False]) +@parameterize("model_name", ["repeated_computed_layers", "bert", "albert", "gpt2"]) +@parameterize("use_grad_checkpoint", [False, True]) def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -25,7 +25,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ model = model_builder(use_grad_checkpoint).cuda() - print(f'model_name {model_name}') + print(f"model_name {model_name}") runtime_mem_tracer = RuntimeMemTracer(model) for i, (input_ids, label) in enumerate(train_dataloader): if i > 0: @@ -37,17 +37,17 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) memstats = runtime_mem_tracer.memstats() runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list - print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data)) - print('runtime tracer: ', runtime_tracer_non_model_data) + print("runtime tracer non model data points: ", len(runtime_tracer_non_model_data)) + print("runtime tracer: ", runtime_tracer_non_model_data) print([memstats.param_used_step(p) for p in model.parameters()]) - if model_name == 'repeated_computed_layers': + if model_name == "repeated_computed_layers": for idx, p in enumerate(model.parameters()): step_list = memstats.param_used_step(p) if idx < 4: assert len(step_list) == 4 - if model_name == 'repeated_computed_layers': + if model_name == "repeated_computed_layers": for idx, p in enumerate(model.parameters()): step_list = memstats.param_used_step(p) if idx < 4: @@ -55,13 +55,11 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gather - model = GeminiDDP(model, - chunk_config_dict=config_dict, - placement_policy=placement_policy, - pin_memory=True, - memstats=memstats) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gather + model = GeminiDDP( + model, chunk_config_dict=config_dict, placement_policy=placement_policy, pin_memory=True, memstats=memstats + ) set_seed(dist.get_rank()) for i, (input_ids, label) in enumerate(train_dataloader): @@ -73,29 +71,30 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ input_ids, label = input_ids.cuda(), label.cuda() set_seed(42) - loss = run_fwd_bwd(model, input_ids, label, criterion, model) + run_fwd_bwd(model, input_ids, label, criterion, model) - gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list('cuda') + gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list("cuda") # print('gemini non model data:', gemini_non_model_data) - assert len(gemini_non_model_data) == len(runtime_tracer_non_model_data), \ - f'model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}' + assert len(gemini_non_model_data) == len( + runtime_tracer_non_model_data + ), f"model_name {model_name} {len(gemini_non_model_data)} vs {len(runtime_tracer_non_model_data)}" def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_gemini_use_rmt() @pytest.mark.skip("this is not used") @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_gemini_use_rmt(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_gemini_use_rmt(1) diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 860d6efa899a..d8bcc555a15d 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -16,26 +16,24 @@ PLACEMENT_CONFIGS = [ { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.0, - 'offload_param_frac': 0.0 - }, # zero2 + "placement_policy": "static", + "shard_param_frac": 0.0, + "offload_optim_frac": 0.0, + "offload_param_frac": 0.0, + }, # zero2 { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 1.0, - 'offload_param_frac': 0.0 - }, # zero2-offload + "placement_policy": "static", + "shard_param_frac": 0.0, + "offload_optim_frac": 1.0, + "offload_param_frac": 0.0, + }, # zero2-offload { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.5, - 'offload_param_frac': 0.0 - }, # zero2-offload-half - { - 'placement_policy': 'auto' - } + "placement_policy": "static", + "shard_param_frac": 0.0, + "offload_optim_frac": 0.5, + "offload_param_frac": 0.0, + }, # zero2-offload-half + {"placement_policy": "auto"}, ] @@ -52,15 +50,15 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module): assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('model_name', ['gpt2']) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", ["gpt2"]) def exam_grad_clipping(placement_config, model_name: str): set_seed(1912) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) @@ -72,18 +70,16 @@ def exam_grad_clipping(placement_config, model_name: str): world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False - if placement_config['placement_policy'] != 'cuda': - init_device = torch.device('cpu') + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = False + if placement_config["placement_policy"] != "cuda": + init_device = torch.device("cpu") else: init_device = None - model = GeminiDDP(model, - chunk_config_dict=config_dict, - chunk_init_device=init_device, - pin_memory=True, - **placement_config) + model = GeminiDDP( + model, chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, **placement_config + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0) @@ -106,6 +102,7 @@ def exam_grad_clipping(placement_config, model_name: str): assert_close(torch_loss, loss) import apex.amp as apex_amp + torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0) torch_optim.step() zero_optim.step() @@ -115,16 +112,16 @@ def exam_grad_clipping(placement_config, model_name: str): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_grad_clipping() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_grad_clip(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_grad_clip(2) diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 99ee08c1d7e7..2b2b246a9f54 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -18,21 +18,10 @@ from tests.components_to_test.registry import non_distributed_component_funcs PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0 - }, # zero3 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.5 - }, # zero3-half - { - 'placement_policy': 'auto' - } + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, ] @@ -52,8 +41,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module): def multi_chunk_init(model: torch.nn.Module, placement_config: dict): world_size = dist.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = False model = GeminiDDP(model, config_dict, pin_memory=True, **placement_config) return model @@ -63,16 +52,16 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict): return model -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('model_name', ['gpt2']) -@parameterize('model_init_func', [single_chunk_init, multi_chunk_init]) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", ["gpt2"]) +@parameterize("model_init_func", [single_chunk_init, multi_chunk_init]) def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable): set_seed(19360226) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) @@ -121,16 +110,16 @@ def inference_iter(): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_inference() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_inference(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_inference(1) diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 3454959199d2..b7c08392600f 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -16,50 +16,30 @@ from tests.components_to_test.registry import non_distributed_component_funcs PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0}, # zero2-offload + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5}, # zero2-offload-half + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 1.0 - }, # zero2-offload - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.5 - }, # zero2-offload-half - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0 - }, # zero3 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.5 - }, # zero3-half - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0, - 'offload_optim_frac': 1.0, - 'offload_param_frac': 1.0 - }, # zero3-offload-all - { - 'placement_policy': 'auto' - } + "placement_policy": "static", + "shard_param_frac": 1.0, + "offload_optim_frac": 1.0, + "offload_param_frac": 1.0, + }, # zero3-offload-all + {"placement_policy": "auto"}, ] # this model is large enough to slice to chunks -TEST_MODELS = ['gpt2'] +TEST_MODELS = ["gpt2"] # these models are too small, all parameters in these models are compacted into one chunk -EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers'] +EXAMPLE_MODELS = ["albert", "beit", "bert", "hanging_param_model", "nested_model", "repeated_computed_layers"] # bfloat16 cannot represent them exactly BF16_IGNORED_KEYS = [ - 'albert.embeddings.word_embeddings.weight', - 'albert.embeddings.position_embeddings.weight', - 'masked_bias', + "albert.embeddings.word_embeddings.weight", + "albert.embeddings.position_embeddings.weight", + "masked_bias", ] @@ -78,23 +58,25 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty if dtype is torch.bfloat16: rtol, atol = 4e-3, 8e-3 # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert_close(value.float(), - temp_zero_value.float(), - rtol=rtol, - atol=atol, - msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}') - - -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('model_name', TEST_MODELS) -@parameterize('mixed_precision', [torch.half, torch.bfloat16]) + assert_close( + value.float(), + temp_zero_value.float(), + rtol=rtol, + atol=atol, + msg=lambda s: s + f"\n{key}\n{temp_zero_value.dtype}", + ) + + +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", TEST_MODELS) +@parameterize("mixed_precision", [torch.half, torch.bfloat16]) def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) @@ -106,8 +88,8 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = False + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = False model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision) optimizer = HybridAdam(model.parameters(), lr=1e-3) @@ -135,16 +117,16 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt check_param(model, torch_model, mixed_precision) -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('model_name', EXAMPLE_MODELS) -@parameterize('mixed_precision', [torch.half, torch.bfloat16]) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", EXAMPLE_MODELS) +@parameterize("mixed_precision", [torch.half, torch.bfloat16]) def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype): set_seed(2008) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=2) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[dist.get_rank()]) @@ -154,12 +136,14 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. for torch_p, p in zip(torch_model.parameters(), model.parameters()): p.data.copy_(torch_p.data) - model = GeminiDDP(model, - chunk_init_device=get_current_device(), - search_range_m=1, - pin_memory=True, - mixed_precision=mixed_precision, - **placement_config) + model = GeminiDDP( + model, + chunk_init_device=get_current_device(), + search_range_m=1, + pin_memory=True, + mixed_precision=mixed_precision, + **placement_config, + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2) @@ -182,7 +166,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12 + assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12 zero_optim.step() torch_optim.step() @@ -192,17 +176,17 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_model_step() exam_tiny_example() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_optim(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_optim(1) diff --git a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py index 29bd61390523..8e0f6ae36c46 100644 --- a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py @@ -13,7 +13,7 @@ @pytest.mark.skip("this is not used") @clear_cache_before_run() def test_runtime_mem_tracer(): - test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert'] + test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"] for model_name in test_models: get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -35,7 +35,7 @@ def test_runtime_mem_tracer(): for p1, p2 in zip(model_bk.parameters(), model.parameters()): torch.allclose(p1.to(torch.half), p2) - non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list('cuda') + non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list("cuda") cuda_non_model_data_list = np.array(non_model_data_list) / 1024**2 print("cuda_non_model_data_list", len(cuda_non_model_data_list)) print(non_model_data_list) @@ -46,9 +46,9 @@ def test_runtime_mem_tracer(): cnt2 = 0 for p in model.parameters(): cnt2 += 1 - assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}' + assert cnt2 == cnt1, f"visited param number {cnt1} vs real param number {cnt2}" del model -if __name__ == '__main__': +if __name__ == "__main__": test_runtime_mem_tracer() diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index 4c7f2ee6c132..e22e5ece42a5 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -11,19 +11,17 @@ def exam_search_chunk_size(): world_size = torch.distributed.get_world_size() - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() # make sure torch_model and model has the same parameter values model = model_builder() - config_dict, *_ = search_chunk_configuration(model, - search_range_m=1, - search_interval=16, - min_chunk_size_m=0, - filter_exlarge_params=True) + config_dict, *_ = search_chunk_configuration( + model, search_range_m=1, search_interval=16, min_chunk_size_m=0, filter_exlarge_params=True + ) for key in config_dict: - chunk_size = config_dict[key]['chunk_size'] + chunk_size = config_dict[key]["chunk_size"] if world_size == 1 or True: assert chunk_size == 31616 else: @@ -33,34 +31,36 @@ def exam_search_chunk_size(): def exam_chunk_manager(): world_size = torch.distributed.get_world_size() - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() sharded_ddp_model = model_builder() - chunk_manager = init_chunk_manager(sharded_ddp_model, - get_current_device(), - hidden_dim=16, - search_range_m=1, - min_chunk_size_m=0, - filter_exlarge_params=True, - strict_ddp_flag=True) + chunk_manager = init_chunk_manager( + sharded_ddp_model, + get_current_device(), + hidden_dim=16, + search_range_m=1, + min_chunk_size_m=0, + filter_exlarge_params=True, + strict_ddp_flag=True, + ) config_dict = chunk_manager.dp_degree_chunk_size_dict assert len(config_dict) == 1 assert config_dict[world_size] == 31616 def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_search_chunk_size() exam_chunk_manager() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_search(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_search(4) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 602e3ad3519d..3130440bd925 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -10,21 +10,10 @@ from tests.components_to_test.registry import non_distributed_component_funcs PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 1.0 - }, # zero3 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.5 - }, # zero3-half - { - 'placement_policy': 'auto' - } + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, ] @@ -35,9 +24,9 @@ def ignore_the_first_parameter(model: torch.nn.Module): return -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('keep_gathered', [True, False]) -@parameterize('model_name', ['gpt2', 'bert']) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [True, False]) +@parameterize("model_name", ["gpt2", "bert"]) def exam_state_dict(placement_config, keep_gathered, model_name: str): set_seed(431) get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -51,8 +40,8 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str): world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) model.train() @@ -65,9 +54,9 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str): assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('keep_gathered', [True, False]) -@parameterize('model_name', ['gpt2', 'bert']) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [True, False]) +@parameterize("model_name", ["gpt2", "bert"]) def exam_load_state_dict(placement_config, keep_gathered, model_name: str): set_seed(431) get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -76,12 +65,12 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str): model = model_builder() set_seed(451) - torch_model = model_builder() # get a different model + torch_model = model_builder() # get a different model world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) @@ -95,8 +84,8 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str): assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('model_name', ['gpt2', 'bert']) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("model_name", ["gpt2", "bert"]) def exam_state_dict_shard(placement_config, model_name: str): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -122,18 +111,18 @@ def exam_state_dict_shard(placement_config, model_name: str): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() exam_load_state_dict() exam_state_dict_shard() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_zero_ddp(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_ddp(1) diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index 5f7b51510d58..8aa656b74cf9 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -11,32 +11,18 @@ from tests.components_to_test.registry import non_distributed_component_funcs PLACEMENT_CONFIGS = [ - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.0 - }, # zero2 - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 1.0 - }, # zero2-offload - { - 'placement_policy': 'static', - 'shard_param_frac': 0.0, - 'offload_optim_frac': 0.5 - }, # zero2-offload-half - { - 'placement_policy': 'auto' - } + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0}, # zero2-offload + {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5}, # zero2-offload-half + {"placement_policy": "auto"}, ] -@parameterize('placement_config', PLACEMENT_CONFIGS) -@parameterize('keep_gathered', [True, False]) +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [True, False]) def exam_zero_optim_state_dict(placement_config, keep_gathered): set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable('gpt2') + get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model = model_builder() @@ -45,13 +31,13 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]['chunk_size'] = 5000 - config_dict[world_size]['keep_gathered'] = keep_gathered + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) optimizer = HybridAdam(model.parameters()) - optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 + optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32 set_seed(dist.get_rank() * 3 + 128) model.train() @@ -67,8 +53,8 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): optim_state_dict = optim.state_dict() optim.load_state_dict(optim_state_dict) - new_state = optim.state_dict()['state'] - org_state = optim_state_dict['state'] + new_state = optim.state_dict()["state"] + org_state = optim_state_dict["state"] for k, v in org_state.items(): w = new_state[k] @@ -82,16 +68,16 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): def run_dist(rank, world_size, port): config = {} - colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_zero_optim_state_dict() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_zero_optim(world_size): spawn(run_dist, world_size) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_optim(1) diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index f170f7cb83da..3c5baea138e0 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -14,7 +14,6 @@ class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(128, 256) @@ -36,16 +35,12 @@ def exam_zero_1_2_grad_acc(): # create optimizer zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) - zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, - overlap_communication=True, - initial_scale=32, - clip_grad_norm=1.0, - verbose=True) - zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, - overlap_communication=True, - partition_grad=True, - initial_scale=32, - clip_grad_norm=1.0) + zero1_optimizer = LowLevelZeroOptimizer( + zero1_optimizer, overlap_communication=True, initial_scale=32, clip_grad_norm=1.0, verbose=True + ) + zero2_optimizer = LowLevelZeroOptimizer( + zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=32, clip_grad_norm=1.0 + ) # create data seed_all(2021 + local_rank) input_data1 = torch.randn(32, 128).cuda() @@ -91,10 +86,9 @@ def exam_zero_1_grad_acc(sync): # we only test stage 1 here # in `check_sharded_param_consistency.py`, we will test whether # level 1 and 2 will produce exactly the same results - zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, - overlap_communication=False, - reduce_bucket_size=262144, - clip_grad_norm=1.0) + zero_optimizer = LowLevelZeroOptimizer( + zero_optimizer, overlap_communication=False, reduce_bucket_size=262144, clip_grad_norm=1.0 + ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -104,7 +98,6 @@ def exam_zero_1_grad_acc(sync): input_data2 = torch.randn(32, 128).cuda() def fwd_bwd_func(no_sync, cur_data, check_flag): - # zero1 fwd and bwd with conditional_context(zero_optimizer.no_sync(), no_sync): zero_output = zero_model(cur_data) @@ -135,7 +128,7 @@ def fwd_bwd_func(no_sync, cur_data, check_flag): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") exam_zero_1_grad_acc(sync=True) exam_zero_1_grad_acc(sync=False) @@ -147,5 +140,5 @@ def test_grad_accumulation(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_grad_accumulation() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index 9c4474aff5c3..ebda9f6f25c5 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -2,7 +2,6 @@ import pytest import torch -import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close @@ -14,7 +13,6 @@ class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(123, 253) @@ -74,14 +72,12 @@ def exam_zero_1_2(): # create optimizer zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) - zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, - overlap_communication=True, - initial_scale=128, - verbose=True) - zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, - overlap_communication=True, - partition_grad=True, - initial_scale=128) + zero1_optimizer = LowLevelZeroOptimizer( + zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True + ) + zero2_optimizer = LowLevelZeroOptimizer( + zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128 + ) # create data seed_all(2001 + local_rank) input_data = torch.randn(32, 123).cuda() @@ -109,7 +105,7 @@ def exam_zero_1_2(): assert torch.equal(z1p.data, z2p.data) -@parameterize('dtype', [torch.float16, torch.bfloat16]) +@parameterize("dtype", [torch.float16, torch.bfloat16]) def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): """ In this test, two pairs of model and optimizers are created. @@ -134,10 +130,9 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): # we only test stage 1 here # in `check_sharded_param_consistency.py`, we will test whether # level 1 and 2 will produce exactly the same results - zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, - overlap_communication=True, - initial_scale=1, - reduce_bucket_size=1024 * 1024) + zero_optimizer = LowLevelZeroOptimizer( + zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=1024 * 1024 + ) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) @@ -178,7 +173,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") exam_zero_1_torch_ddp(world_size=world_size) exam_zero_1_2() @@ -190,5 +185,5 @@ def test_zero_1_2(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_1_2() diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py index ab811c6b4d3c..e9fc8598a62d 100644 --- a/tests/test_zero/test_low_level/test_zero_ckpt.py +++ b/tests/test_zero/test_low_level/test_zero_ckpt.py @@ -2,19 +2,17 @@ import pytest import torch -import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer class MlpModel(nn.Module): - def __init__(self): super(MlpModel, self).__init__() self.linear1 = nn.Linear(12, 24) @@ -61,10 +59,9 @@ def exam_zero_1_torch_ddp_ckpt(): # we only test stage 1 here # the state dicts of stage 1 and stage 2 are the same - zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, - overlap_communication=True, - initial_scale=1, - reduce_bucket_size=262144) + zero_optimizer = LowLevelZeroOptimizer( + zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=262144 + ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -88,7 +85,7 @@ def exam_zero_1_torch_ddp_ckpt(): zero_state_dict = zero_optimizer.state_dict() # examine the original state dict - for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()): + for torch_state, zero_state in zip(torch_state_dict["state"].values(), zero_state_dict["state"].values()): for t_v, z_v in zip(torch_state.values(), zero_state.values()): loose_close(t_v, z_v) @@ -100,13 +97,13 @@ def exam_zero_1_torch_ddp_ckpt(): zero_state_dict = zero_optimizer.state_dict() # examine the loaded state dict - for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()): + for torch_state, zero_state in zip(torch_state_dict["state"].values(), zero_state_dict["state"].values()): for t_v, z_v in zip(torch_state.values(), zero_state.values()): loose_close(t_v, z_v) def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") exam_zero_1_torch_ddp_ckpt() @@ -117,5 +114,5 @@ def test_zero_ckpt(): spawn(run_dist, 2) -if __name__ == '__main__': +if __name__ == "__main__": test_zero_ckpt()