(fix) Make the weighted average fit for all kinds of systems #4593
base: devel
Conversation
📝 Walkthrough
This pull request modifies the testing functionalities within the DeePMD framework.
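The change discussed in this PR can be made concrete with a small sketch. This is not the DeePMD-kit implementation — the function name `weighted_average` and the `(error, size)` dictionary layout are assumptions pieced together from the review comments that follow: each system reports metrics as `(error, sample_size)` pairs, and a metric is averaged only over the systems that actually provide it.

```python
# Hypothetical sketch of the flag-aware weighted-average scheme (names and
# data layout are illustrative, not the actual deepmd code).
def weighted_average(err_coll):
    """err_coll: list of dicts mapping metric name -> (error, size)."""
    sum_err, sum_siz = {}, {}
    for err in err_coll:
        for key, (val, size) in err.items():
            sum_err[key] = sum_err.get(key, 0.0) + val * size
            sum_siz[key] = sum_siz.get(key, 0) + size
    return {k: sum_err[k] / sum_siz[k] for k in sum_err}

# Two systems: only the first has force labels, so "mae_f" is not
# diluted by the energy-only second system.
avg = weighted_average(
    [
        {"mae_e": (2.0, 4), "mae_f": (1.0, 12)},
        {"mae_e": (4.0, 4)},
    ]
)
print(avg["mae_e"])  # (2.0*4 + 4.0*4) / 8 = 3.0
print(avg["mae_f"])  # 1.0
```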
Actionable comments posted: 0
🧹 Nitpick comments (4)
source/tests/pt/test_weighted_avg.py (1)
70-101: Well-structured test cases with good coverage!
The test cases effectively validate different component combinations. Consider making the variable names more descriptive for better readability.
Consider renaming variables to be more descriptive:
```diff
- expected_mae_f = (2*3 + 1*3)/(3+3)
+ expected_force_mae = (2*3 + 1*3)/(3+3)
- expected_mae_v = (3*5 + 1*5)/(5+5)
+ expected_virial_mae = (3*5 + 1*5)/(5+5)
```

deepmd/entrypoints/test.py (3)
331-333: Good addition of component flags!
Consider using more descriptive variable names for better clarity:

```diff
- find_energy = test_data.get('find_energy')
- find_force = test_data.get('find_force')
- find_virial = test_data.get('find_virial')
+ has_energy_component = test_data.get('find_energy')
+ has_force_component = test_data.get('find_force')
+ has_virial_component = test_data.get('find_virial')
```
146-167: Good selective error collection logic!
Consider adding error handling for missing components:

```diff
  err_part = {}
+ if test_data.get('find_energy') is None:
+     log.warning("Energy component flag not found in test data")
+ if test_data.get('find_force') is None:
+     log.warning("Force component flag not found in test data")
+ if test_data.get('find_virial') is None:
+     log.warning("Virial component flag not found in test data")
  if find_energy == 1:
      err_part['mae_e'] = err['mae_e']
```
459-470: Good conditional logging implementation!
Consider adding debug logging for better troubleshooting:

```diff
+ log.debug(f"Processing system with energy={find_energy}, force={find_force}, virial={find_virial}")
  if find_force == 1:
      if not out_put_spin:
          log.info(f"Force MAE : {mae_f:e} eV/A")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
deepmd/entrypoints/test.py (12 hunks)
source/tests/pt/test_weighted_avg.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (19)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Analyze (python)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (3)
source/tests/pt/test_weighted_avg.py (3)
7-31: Well-structured implementation for handling different error metrics!
The function effectively handles different combinations of energy, force, and virial metrics with clean conditional logic and proper error collection.
33-39: Clean baseline implementation!
The function provides a good reference point for comparing weighted averages with and without filtering.
43-67: Comprehensive test coverage for energy-only metrics!
The test case effectively validates:
- Correct weighted average calculations
- Proper handling of energy-only systems
- Edge cases with force and virial metrics
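As a hedged illustration of the edge case listed above (the numbers are invented, not taken from the test file): when an energy-only system still carries a meaningless force entry, an unfiltered average is skewed, while the flag-aware average ignores it.

```python
# Illustrative numbers only: sys_b has no real force labels, but a naive
# average would still fold its (junk) force error into the mean.
sys_a_mae_f, sys_a_n = 2.0, 3   # system with real force labels
sys_b_mae_f, sys_b_n = 9.0, 3   # energy-only system, junk force error

naive = (sys_a_mae_f * sys_a_n + sys_b_mae_f * sys_b_n) / (sys_a_n + sys_b_n)
filtered = (sys_a_mae_f * sys_a_n) / sys_a_n  # only systems with find_force == 1

print(naive)     # 5.5 — polluted by the junk entry
print(filtered)  # 2.0 — only real force labels contribute
```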
Actionable comments posted: 0
🧹 Nitpick comments (4)
source/tests/pt/test_weighted_avg.py (2)
11-36: Consider adding type hints and docstring.
The function lacks type hints and documentation, which would improve code maintainability and help users understand the expected input/output format. Apply this diff:

```diff
-def test(all_sys):
+def test(all_sys: list[tuple[dict, int, int, int]]) -> dict[str, tuple[float, int]]:
+    """Calculate weighted average of errors with selective inclusion of metrics.
+
+    Args:
+        all_sys: List of tuples containing (error_dict, find_energy, find_force, find_virial)
+            where error_dict contains the error metrics
+
+    Returns:
+        Dictionary mapping error names to tuples of (error_value, sample_size)
+    """
     err_coll = []
```
38-44: Add type hints and docstring to test_ori function.
Similar to the test function, this function would benefit from type hints and documentation. Apply this diff:

```diff
-def test_ori(all_sys):
+def test_ori(all_sys: list[tuple[dict, int, int, int]]) -> dict[str, tuple[float, int]]:
+    """Calculate weighted average of all errors without selective inclusion.
+
+    Args:
+        all_sys: List of tuples containing (error_dict, find_energy, find_force, find_virial)
+            where error_dict contains the error metrics
+
+    Returns:
+        Dictionary mapping error names to tuples of (error_value, sample_size)
+    """
     err_coll = []
```

deepmd/entrypoints/test.py (2)
331-334: Consider using dictionary get() with default values.
The code uses get() without default values, which could return None. Consider providing default values for safety:

```diff
- find_energy = test_data.get("find_energy")
- find_force = test_data.get("find_force")
- find_virial = test_data.get("find_virial")
+ find_energy = test_data.get("find_energy", 0)
+ find_force = test_data.get("find_force", 0)
+ find_virial = test_data.get("find_virial", 0)
```
744-747: Use f-strings instead of the % operator for string formatting.
The code uses the older % operator for string formatting. Consider using f-strings for better readability and maintainability:

```diff
- detail_path.with_suffix(".dos.out.%.d" % ii),
+ detail_path.with_suffix(f".dos.out.{ii:d}"),
  frame_output,
- header="%s - %.d: data_dos pred_dos" % (system, ii),
+ header=f"{system} - {ii:d}: data_dos pred_dos",
- detail_path.with_suffix(".ados.out.%.d" % ii),
+ detail_path.with_suffix(f".ados.out.{ii:d}"),
  frame_output,
- header="%s - %.d: data_ados pred_ados" % (system, ii),
+ header=f"{system} - {ii:d}: data_ados pred_ados",
```

Also applies to: 758-761
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
deepmd/entrypoints/test.py (10 hunks)
source/tests/pt/test_weighted_avg.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (19)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
🔇 Additional comments (6)
source/tests/pt/test_weighted_avg.py (3)
47-113: LGTM! Well-structured test case for energy-only scenario.
The test case thoroughly validates the weighted average calculation for energy metrics, including proper assertions for both the main function and original implementation.
114-171: LGTM! Comprehensive test case for energy and force metrics.
The test case effectively validates the combined energy and force calculations, with appropriate assertions to verify the differences between the two implementations.
172-229: LGTM! Complete test coverage for all components.
The test case provides thorough validation of all metrics (energy, force, virial) with appropriate assertions.
deepmd/entrypoints/test.py (3)
137-168: LGTM! Improved error handling with selective metric inclusion.
The changes enhance error handling by:
- Using flags to determine which metrics to include
- Organizing errors into a separate dictionary
- Handling different force metric cases (regular vs spin)
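The bullets above can be sketched as follows. This is a simplified, hypothetical version — the real function handles many more metrics and distinguishes the regular and spin force keys:

```python
# Hedged sketch of flag-gated error collection: only metrics whose
# component flag is set end up in the per-system error dictionary.
def collect_errors(err, find_energy, find_force, find_virial):
    err_part = {}
    if find_energy == 1:
        err_part["mae_e"] = err["mae_e"]
    if find_force == 1:
        # the real code branches on regular vs spin force keys here
        err_part["mae_f"] = err["mae_f"]
    if find_virial == 1:
        err_part["mae_v"] = err["mae_v"]
    return err_part

err = {"mae_e": (2.0, 2), "mae_f": (0.5, 6), "mae_v": (3.0, 9)}
part = collect_errors(err, find_energy=1, find_force=1, find_virial=0)
print(sorted(part))  # ['mae_e', 'mae_f'] — virial never enters the average
```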
459-468: LGTM! Improved conditional logging for force metrics.
The changes enhance logging by:
- Only logging when force data is present
- Handling different force metric types (regular vs spin)
469-469: LGTM! Added condition for virial logging.
The change ensures virial metrics are only logged when appropriate (PBC enabled and virial data present).
Actionable comments posted: 0
🧹 Nitpick comments (3)
source/tests/pt/test_weighted_avg.py (3)
11-36: Add type hints and docstring for better maintainability.
The function needs documentation and type hints to improve maintainability. Also, consider using an enum or constants for flag values.

```diff
-def fake_test(all_sys):
+from typing import List, Dict, Tuple, Union
+
+def fake_test(all_sys: List[Tuple[Dict[str, Tuple[float, float]], int, int, int]]) -> Dict[str, float]:
+    """Calculate weighted average of error metrics based on system components.
+
+    Args:
+        all_sys: List of tuples containing (error_dict, has_energy, has_force, has_virial)
+            where error_dict contains metrics like mae_e, rmse_e etc.
+
+    Returns:
+        Dictionary containing weighted averages of error metrics.
+    """
```

Also, add error handling for missing force metrics:

```diff
  if find_force == 1:
+     if not any(key in err for key in ['rmse_f', 'rmse_fr']):
+         raise KeyError("No force metrics found in error dictionary")
      if "rmse_f" in err:
```
38-44: Add type hints and docstring for consistency.
For consistency with fake_test, add type hints and documentation:

```diff
-def fake_test_ori(all_sys):
+def fake_test_ori(all_sys: List[Tuple[Dict[str, Tuple[float, float]], int, int, int]]) -> Dict[str, float]:
+    """Calculate weighted average of error metrics ignoring component flags.
+
+    Used as a baseline for comparison with fake_test.
+
+    Args:
+        all_sys: List of tuples containing (error_dict, has_energy, has_force, has_virial)
+
+    Returns:
+        Dictionary containing weighted averages of all error metrics.
+    """
```
47-229: Enhance test maintainability and coverage.
While the test cases are comprehensive, consider these improvements:
- Move test data to class-level setup
- Add docstrings to test methods
- Add negative test cases
Example refactor:

```diff
  class TestWeightedAverage(unittest.TestCase):
+     def setUp(self):
+         """Set up test data."""
+         # Define common test data structure
+         self.base_system = {
+             "mae_e": (2, 2),
+             "mae_ea": (4, 2),
+             "rmse_e": (3, 2),
+             "rmse_ea": (5, 2),
+             "mae_f": (2, 3),
+             "rmse_f": (1, 3),
+             "mae_v": (3, 5),
+             "rmse_v": (3, 3),
+         }
+     def test_case1_energy_only(self):
+         """Test weighted average calculation with energy metrics only."""
```

Add negative test cases:

```python
def test_invalid_input(self):
    """Test handling of invalid input data."""
    # Test with empty system list
    with self.assertRaises(ValueError):
        fake_test([])
    # Test with invalid flag values
    invalid_sys = [(self.base_system, 2, 0, 0)]  # Invalid energy flag
    with self.assertRaises(ValueError):
        fake_test(invalid_sys)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/tests/pt/test_weighted_avg.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (20)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
🔇 Additional comments (2)
source/tests/pt/test_weighted_avg.py (2)
1-8: LGTM! Imports are well-organized.
The license is properly specified and imports follow Python conventions.
231-232: LGTM! Standard unittest main block.
Follows Python conventions for unittest execution.
Actionable comments posted: 2
🔭 Outside diff range comments (2)
deepmd/entrypoints/test.py (2)
562-574: ⚠️ Potential issue: Fix incorrect type hints.
The function's return type hint tuple[list[np.ndarray], list[int]] doesn't match the actual return type, which is a dictionary of tuples containing error metrics and their sizes. Update the type hint to match the actual return type:

```diff
-) -> tuple[list[np.ndarray], list[int]]:
+) -> dict[str, tuple[float, int]]:
```

Also, remove the commented-out flags from the return statements:

```diff
- }  # find_energy,find_force,find_virial,
+ }
```

Also applies to: 575-588
137-169: 💡 Verification agent / 🧩 Analysis chain
Verify the impact of removing error filtering.
The commented-out code previously filtered errors based on flags (find_energy, find_force, find_virial). Removing this filtering could lead to inconsistencies in error reporting, especially when certain metrics are not available or relevant.
Run the following script to verify error handling in test cases:

🏁 Script executed:

```shell
#!/bin/bash
# Description: Check test cases for error handling patterns.

# Search for test cases that rely on filtered errors
rg -A 5 'test_ener.*find_' 'source/tests'

# Search for error dictionary usage in test cases
rg -A 5 'err\[.*(mae|rmse)_(e|f|v)' 'source/tests'
```

Length of output: 4540
Inconsistent error reporting: unfiltered error dictionary vs. filtered expectations in tests
The removal of the error filtering (i.e., the commented-out aggregation logic that created a filtered err_part) in the test_ener function means that the function now returns a complete error dictionary. However, test cases in source/tests/pt/test_weighted_avg.py still expect the filtered error metrics (e.g., keys such as "mae_e", "rmse_e", etc., populated conditionally via find_energy, find_force, and find_virial). This mismatch can lead to discrepancies in the averaged error computation or even test failures if additional keys or unexpected values are present.
- Action Item: Either update the test case expectations to work with the full error dictionary, or reintroduce/adjust the filtering logic to ensure that the returned error dictionary strictly contains the expected keys and values.
🧹 Nitpick comments (1)
deepmd/entrypoints/test.py (1)
735-738: Use f-strings for better readability.
The code uses the % operator for string formatting, which is outdated in Python. Using f-strings would improve readability and maintainability:

```diff
- detail_path.with_suffix(".dos.out.%.d" % ii),
- header="%s - %.d: data_dos pred_dos" % (system, ii),
+ detail_path.with_suffix(f".dos.out.{ii}"),
+ header=f"{system} - {ii}: data_dos pred_dos",
- detail_path.with_suffix(".ados.out.%.d" % ii),
- header="%s - %.d: data_ados pred_ados" % (system, ii),
+ detail_path.with_suffix(f".ados.out.{ii}"),
+ header=f"{system} - {ii}: data_ados pred_ados",
- detail_path.with_suffix(".property.out.%.d" % ii),
- header="%s - %.d: data_property pred_property" % (system, ii),
+ detail_path.with_suffix(f".property.out.{ii}"),
+ header=f"{system} - {ii}: data_property pred_property",
- detail_path.with_suffix(".aproperty.out.%.d" % ii),
- header="%s - %.d: data_aproperty pred_aproperty" % (system, ii),
+ detail_path.with_suffix(f".aproperty.out.{ii}"),
+ header=f"{system} - {ii}: data_aproperty pred_aproperty",
```

Also applies to: 749-752, 891-894, 905-908
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/entrypoints/test.py (12 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/entrypoints/test.py
332-332: Local variable find_energy is assigned to but never used. Remove assignment to unused variable find_energy. (F841)
⏰ Context from checks skipped due to timeout of 90000ms (5)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (1)
deepmd/entrypoints/test.py (1)
810-812: LGTM! Improved code consistency.
The changes to the property test function improve readability by:
- Using consistent variable names (property instead of mixed usage)
- Using task_dim consistently for dimensionality

Also applies to: 863-864, 868-869
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/entrypoints/test.py (2)
734-750: Use f-strings for string formatting.
For better readability and consistency with modern Python practices, consider using f-strings instead of the % formatting style:

```diff
- detail_path.with_suffix(".dos.out.%.d" % ii),
+ detail_path.with_suffix(f".dos.out.{ii:d}"),
  frame_output,
- header="%s - %.d: data_dos pred_dos" % (system, ii),
+ header=f"{system} - {ii:d}: data_dos pred_dos",
  append=append_detail,
  )
  save_txt_file(
- detail_path.with_suffix(".ados.out.%.d" % ii),
+ detail_path.with_suffix(f".ados.out.{ii:d}"),
  frame_output,
- header="%s - %.d: data_ados pred_ados" % (system, ii),
+ header=f"{system} - {ii:d}: data_ados pred_ados",
  append=append_detail,
  )
```
890-906: Use f-strings for string formatting.
For better readability and consistency with modern Python practices, consider using f-strings instead of the % formatting style:

```diff
- detail_path.with_suffix(".property.out.%.d" % ii),
+ detail_path.with_suffix(f".property.out.{ii:d}"),
  frame_output,
- header="%s - %.d: data_property pred_property" % (system, ii),
+ header=f"{system} - {ii:d}: data_property pred_property",
  append=append_detail,
  )
  save_txt_file(
- detail_path.with_suffix(".aproperty.out.%.d" % ii),
+ detail_path.with_suffix(f".aproperty.out.{ii:d}"),
  frame_output,
- header="%s - %.d: data_aproperty pred_aproperty" % (system, ii),
+ header=f"{system} - {ii:d}: data_aproperty pred_aproperty",
  append=append_detail,
  )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/entrypoints/test.py (12 hunks)
🔇 Additional comments (4)
deepmd/entrypoints/test.py (4)
137-145: LGTM! Function signature update improves type safety.
The updated function signature now correctly returns the error dictionary along with the flags, making the return type explicit and improving type safety.
146-167: LGTM! Error filtering logic is now consistent.
The error filtering logic has been improved to selectively include metrics based on the flags, addressing the inconsistency mentioned in past review comments.
809-811: LGTM! Using task_dim improves flexibility.
Using dp.task_dim for property dimensions makes the code more flexible and maintainable by relying on the model's configuration.
331-333: Static analysis warning can be ignored.
The flags find_energy, find_force, and find_virial are now properly used in error filtering (lines 148-165) and returned from the function. The static analysis warning is outdated.
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/entrypoints/test.py (1)
744-747: Modernize string formatting.
Replace old-style % string formatting with f-strings for better readability and maintainability:

```diff
- detail_path.with_suffix(".dos.out.%.d" % ii),
+ detail_path.with_suffix(f".dos.out.{ii:d}"),
- header="%s - %.d: data_dos pred_dos" % (system, ii),
+ header=f"{system} - {ii:d}: data_dos pred_dos",
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/entrypoints/test.py (10 hunks)
🔇 Additional comments (4)
deepmd/entrypoints/test.py (4)
137-167: LGTM! Improved error handling and type safety.
The changes improve error handling by making the function return type explicit and filtering error metrics based on what was actually calculated. The structured error dictionary now only includes relevant metrics.
331-333: LGTM! Resolved unused flag issues.
The previously unused flags are now properly utilized for error filtering and conditional logging, addressing past review comments and static analysis warnings.
459-470: LGTM! Consistent flag usage in logging.
The changes implement consistent conditional logging based on the presence of force and virial calculations, with proper handling of spin calculations.
819-821: LGTM! Consistent property testing implementation.
The property testing implementation follows the same patterns as other test functions, with consistent dimension handling and error calculation.
Also applies to: 872-878
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

```
@@           Coverage Diff            @@
##            devel    #4593    +/-   ##
==========================================
- Coverage   84.58%   84.56%   -0.02%
==========================================
  Files         680      680
  Lines       64510    64529      +19
  Branches     3539     3538       -1
==========================================
+ Hits        54563    54571       +8
- Misses       8807     8816       +9
- Partials     1140     1142       +2
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
♻️ Duplicate comments (1)
deepmd/entrypoints/test.py (1)
385-387: ⚠️ Potential issue: Resolve inconsistency in flag usage.
The flags find_energy, find_force, and find_virial were previously marked for removal but are now used in conditional logging. Choose one approach:
- Keep the flags and use them consistently throughout the code
- Remove the flags and always log all available metrics
If keeping the flags, ensure they are used consistently and document their purpose.
🧹 Nitpick comments (3)
deepmd/entrypoints/test.py (3)
257-290: Add type hints to new functions.
The new error handling functions lack type hints, which is inconsistent with the rest of the codebase. Also, consider grouping the parameters into a data class to improve maintainability:

```diff
+from dataclasses import dataclass
+from typing import Dict, Tuple, Union
+
+@dataclass
+class EnergyTestData:
+    find_energy: int
+    find_force: int
+    find_virial: int
+    energy: np.ndarray
+    force: np.ndarray
+    virial: np.ndarray
+    mae_e: float
+    mae_ea: float
+    mae_f: float
+    mae_v: float
+    mae_va: float
+    rmse_e: float
+    rmse_ea: float
+    rmse_f: float
+    rmse_v: float
+    rmse_va: float

-def test_ener_err(
-    find_energy,
-    find_force,
-    find_virial,
-    energy,
-    force,
-    virial,
-    mae_e,
-    mae_ea,
-    mae_f,
-    mae_v,
-    mae_va,
-    rmse_e,
-    rmse_ea,
-    rmse_f,
-    rmse_v,
-    rmse_va,
-):
+def test_ener_err(data: EnergyTestData) -> Dict[str, Tuple[float, int]]:
     err = {}
-    if find_energy == 1:
-        err["mae_e"] = (mae_e, energy.size)
-        err["mae_ea"] = (mae_ea, energy.size)
-        err["rmse_e"] = (rmse_e, energy.size)
-        err["rmse_ea"] = (rmse_ea, energy.size)
+    if data.find_energy == 1:
+        err["mae_e"] = (data.mae_e, data.energy.size)
+        err["mae_ea"] = (data.mae_ea, data.energy.size)
+        err["rmse_e"] = (data.rmse_e, data.energy.size)
+        err["rmse_ea"] = (data.rmse_ea, data.energy.size)
```

Similar changes should be applied to test_ener_err_ops.
Also applies to: 292-330
516-530: Improve logging structure.
The conditional logging blocks could be refactored for better readability and maintainability. Consider using a logging helper function:

```diff
+def log_energy_metrics(mae_e: float, rmse_e: float, mae_ea: float, rmse_ea: float) -> None:
+    log.info(f"Energy MAE : {mae_e:e} eV")
+    log.info(f"Energy RMSE : {rmse_e:e} eV")
+    log.info(f"Energy MAE/Natoms : {mae_ea:e} eV")
+    log.info(f"Energy RMSE/Natoms : {rmse_ea:e} eV")

-    if find_energy == 1:
-        log.info(f"Energy MAE : {mae_e:e} eV")
-        log.info(f"Energy RMSE : {rmse_e:e} eV")
-        log.info(f"Energy MAE/Natoms : {mae_ea:e} eV")
-        log.info(f"Energy RMSE/Natoms : {rmse_ea:e} eV")
+    if find_energy == 1:
+        log_energy_metrics(mae_e, rmse_e, mae_ea, rmse_ea)
```
257-290: Add unit tests for new error handling functions.
The new error handling functions test_ener_err and test_ener_err_ops need test coverage. Would you like me to generate unit tests for these functions to ensure they handle various input scenarios correctly?
Also applies to: 292-330
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
deepmd/entrypoints/test.py (4 hunks)
source/tests/pt/test_weighted_avg.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/pt/test_weighted_avg.py
⏰ Context from checks skipped due to timeout of 90000ms (20)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test C++ (false)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test C++ (true)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (javascript-typescript)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (1)
deepmd/entrypoints/test.py (1)
682-685: ⚠️ Potential issue: Fix unreachable code after return statement.
The Hessian-related code is unreachable as it appears after the return statements. Move the Hessian handling before the return statements:

```diff
  if dp.has_hessian:
-     dict_to_return["mae_h"] = (mae_h, hessian.size)
-     dict_to_return["rmse_h"] = (rmse_h, hessian.size)
-     return dict_to_return
+     err["mae_h"] = (mae_h, hessian.size)
+     err["rmse_h"] = (rmse_h, hessian.size)
  if not out_put_spin:
      return (
          {
```

🧰 Tools
🪛 Ruff (0.8.2)
683-683: Undefined name dict_to_return (F821)
684-684: Undefined name dict_to_return (F821)
685-685: Undefined name dict_to_return (F821)
Actionable comments posted: 2
🔭 Outside diff range comments (1)
deepmd/entrypoints/test.py (1)
687-690: ⚠️ Potential issue: Fix unreachable Hessian code.
This code block is unreachable as it appears after the return statements. Move the Hessian handling before the return statements:

```diff
- if dp.has_hessian:
-     dict_to_return["mae_h"] = (mae_h, hessian.size)
-     dict_to_return["rmse_h"] = (rmse_h, hessian.size)
-     return dict_to_return
+ err = {}
+ if dp.has_hessian:
+     err["mae_h"] = (mae_h, hessian.size)
+     err["rmse_h"] = (rmse_h, hessian.size)
+ if not out_put_spin:
+     return test_ener_err(...)
+ else:
+     return test_ener_err_ops(...)
```

🧰 Tools
🪛 Ruff (0.8.2)
688-688: Undefined name dict_to_return (F821)
689-689: Undefined name dict_to_return (F821)
690-690: Undefined name dict_to_return (F821)
🧹 Nitpick comments (1)
source/tests/pt/test_weighted_avg.py (1)
171-176: Enhance test data validation.
The test uses oversimplified mock data with constant values ([1.0] * 10, [0.1] * 10, etc.). Consider adding test cases with:
- Different array sizes
- Edge cases (zeros, large numbers)
- Random data to catch potential numerical issues
```diff
- energy = np.array([1.0] * 10)
- force = np.array([0.1] * 10)
- virial = np.array([0.2] * 10)
+ # Test different array sizes and data patterns
+ sizes = [10, 100]
+ for size in sizes:
+     # Regular case
+     energy = np.random.random(size)
+     force = np.random.random(size)
+     virial = np.random.random(size)
+     # Edge cases
+     energy_edge = np.zeros(size)
+     force_edge = np.array([1e10] * size)
+     virial_edge = np.array([-1e10] * size)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
deepmd/entrypoints/test.py (4 hunks)
source/tests/pt/test_weighted_avg.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (16)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (cpu, cpu)
🔇 Additional comments (3)

deepmd/entrypoints/test.py (3)

257-289: LGTM! Well-structured error handling.

The function follows good practices:
- Single responsibility principle
- Conditional error reporting based on flags
- Includes data size with each error metric
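Storing each metric as a (value, size) tuple is what lets per-system errors be combined into a dataset-wide weighted average, which is the point of this PR. A minimal sketch of that combination step (the helper name is hypothetical, not the actual deepmd-kit API):

```python
def weighted_average(errors):
    """Combine per-system (error, size) pairs into one weighted average."""
    total = sum(size for _, size in errors)
    return sum(value * size for value, size in errors) / total

# Two systems: MAE 2.0 over 3 data points and MAE 1.0 over 3 data points.
print(weighted_average([(2.0, 3), (1.0, 3)]))  # -> 1.5
```

This mirrors the expected values used in the test cases, e.g. (2*3 + 1*3)/(3+3).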
292-331: LGTM! Consistent error handling for magnetic systems.

The function maintains consistency with `test_ener_err` while adding support for magnetic force components.
387-391: ✅ Verification successful — Verify flag usage in the codebase.

The flags are used for both error calculation and logging. Let's verify their usage across the codebase.
🏁 Script executed:
#!/bin/bash
# Search for all usages of these flags
echo "Searching for flag usage patterns..."
rg "find_(energy|force|virial|force_real|force_mag)" --type py
echo "Searching for test data flag definitions..."
rg "test_data\.get\([\"']find_" --type py

Length of output: 13895
Flag usage verified across the codebase – no changes needed.

The search results confirm that the error and logging flags ("find_energy", "find_force", "find_virial", "find_force_real", and "find_force_mag") are consistently used in multiple modules (e.g., in `deepmd/dpmodel/loss/ener.py`, `deepmd/tf/loss/ener.py`, and various test files). In particular, the definitions in `deepmd/entrypoints/test.py` are in line with their usage elsewhere in the codebase.

- Flag definitions via `test_data.get(...)` in `deepmd/entrypoints/test.py` are used correctly in conditional checks.
- Multiplicative error calculations and logging (e.g., in `deepmd/dpmodel/loss/ener.py`) consistently reference these flags.
- Test files also verify these flags, confirming that both error evaluation and logging functionality follow the expected pattern.

Overall, the flag usage is coherent and properly integrated.
Actionable comments posted: 2
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (2)
- deepmd/entrypoints/test.py (4 hunks)
- source/tests/pt/test_weighted_avg.py (1 hunks)

🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/pt/test_weighted_avg.py
🔇 Additional comments (2)

deepmd/entrypoints/test.py (2)

387-392: LGTM! Improved error handling for spin-polarized systems.

The changes to the `test_ener` function enhance error handling by:
- Adding separate flags for real and magnetic force components
- Implementing conditional logging based on flags
- Using the new error calculation functions consistently

Also applies to: 520-534, 645-685
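The flag-gated pattern described in this comment can be illustrated in isolation — a metric is only logged and collected when its find_* flag is set (function and metric names here are illustrative, not the actual deepmd-kit code):

```python
import logging

log = logging.getLogger(__name__)

def collect_force_metric(mae_f, n_points, find_force):
    """Log and record the force MAE only when its find_* flag is set."""
    results = {}
    if find_force:
        log.info("Force  MAE : %e eV/A", mae_f)
        results["mae_f"] = (mae_f, n_points)
    return results

print(collect_force_metric(0.1, 30, 1))  # -> {'mae_f': (0.1, 30)}
print(collect_force_metric(0.1, 30, 0))  # -> {}
```

Systems that lack a given label simply contribute nothing to that metric's weighted average, which is what makes the fix work "for all kinds of systems".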
687-690: ⚠️ Potential issue — Remove unreachable code after return statements.

The Hessian-related code is unreachable as it appears after the return statements. This code should be moved before the return statements or removed if it's no longer needed.
🧰 Tools
🪛 Ruff (0.8.2)
688-688: Undefined name `dict_to_return` (F821)
689-689: Undefined name `dict_to_return` (F821)
690-690: Undefined name `dict_to_return` (F821)
for more information, see https://pre-commit.ci
…md-kit into debug-weightedavg
deepmd/entrypoints/test.py
Outdated
find_energy = test_data.get("find_energy")
find_force = test_data.get("find_force")
find_virial = test_data.get("find_virial")
find_force_r = test_data.get("find_force_real")
find_force_m = test_data.get("find_force_mag")
Have you checked the correct keys for `force_r` and `force_m` in the spin model? See `deepmd/pt/loss/ener_spin.py` for details. They should be `find_force` and `find_force_mag`.
source/tests/pt/test_weighted_avg.py
Outdated
from deepmd.entrypoints.test import (
    ener_err,
    ener_err_ops,
)
Note that you should test `test_ener` instead of only testing `ener_err` and `ener_err_ops`!
deepmd/entrypoints/test.py
Outdated
@@ -254,6 +254,83 @@ def save_txt_file(
     np.savetxt(fp, data, header=header)


def ener_err(
What do you mean by the names `ener_err` and `ener_err_ops`? They are not good names, and the two functions share almost all of their code.
- Test `test_ener` with a real dp model, such as the one in source/tests/pt/test_dp_test.py.
- Test the model both with and without spin.
- You should not use hard-coded numbers as the reference (`mae_e, mae_ea, mae_f, mae_v, mae_va = 0.1, 0.1, 0.1, 0.1, 0.1`); instead, use the correct calculation process to obtain the reference numbers (`mae_e = reference_calculate(dp, label)`).

The PR is not ready for review! Converted to draft.
for more information, see https://pre-commit.ci
Actionable comments posted: 3

♻️ Duplicate comments (1)

deepmd/entrypoints/test.py (1)

483-484: ⚠️ Potential issue — Fix unreachable code after return statement.

The Hessian metrics are added to the dictionary after the return statement, making this code unreachable.

Move the Hessian metrics before the return statement:

 if dp.has_hessian:
     log.info(f"Hessian MAE : {mae_h:e} eV/A^2")
     log.info(f"Hessian RMSE : {rmse_h:e} eV/A^2")
+    dict_to_return["mae_h"] = (mae_h, hessian.size)
+    dict_to_return["rmse_h"] = (rmse_h, hessian.size)
 if detail_file is not None:
     # ... file writing code ...
 return dict_to_return
-dict_to_return["mae_h"] = (mae_h, hessian.size)
-dict_to_return["rmse_h"] = (rmse_h, hessian.size)
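The bug pattern flagged here is easy to demonstrate in isolation — any statement placed after an unconditional return never executes:

```python
def collect_metrics():
    metrics = {"mae_e": (0.1, 10)}
    return metrics
    # Dead code: the function has already returned, so this never runs.
    metrics["mae_h"] = (0.2, 10)

print("mae_h" in collect_metrics())  # -> False
```

This is also why Ruff reports F821 on those lines in the actual file: in that context the name is referenced outside the scope where it was assigned.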
🧹 Nitpick comments (1)

source/tests/pt/test_weighted_avg.py (1)

86-96: Improve cleanup implementation.

The cleanup implementation has several issues:
- No error handling for file operations
- Hardcoded file patterns
- No logging of cleanup failures

Apply this diff to improve the cleanup:

 def tearDown(self) -> None:
+    """Clean up test artifacts."""
+    patterns = [
+        ("model*.pt", False),
+        (f"{self.detail_file}*", False),
+        ("lcurve.out", False),
+        (self.input_json, False),
+        ("stat_files", True),
+    ]
+
     for f in os.listdir("."):
-        if f.startswith("model") and f.endswith(".pt"):
-            os.remove(f)
-        if f.startswith(self.detail_file):
-            os.remove(f)
-        if f in ["lcurve.out", self.input_json]:
-            os.remove(f)
-        if f in ["stat_files"]:
-            shutil.rmtree(f)
+        for pattern, is_dir in patterns:
+            try:
+                if (pattern.startswith(f) if "*" in pattern else f == pattern):
+                    if is_dir:
+                        shutil.rmtree(f)
+                    else:
+                        os.remove(f)
+                    break
+            except OSError as e:
+                logging.warning(f"Failed to remove {f}: {e}")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (2)
- deepmd/entrypoints/test.py (4 hunks)
- source/tests/pt/test_weighted_avg.py (1 hunks)

🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_weighted_avg.py
35-35: Use a context manager for opening files (SIM115)
source/tests/pt/test_weighted_avg.py (1)
22-34
:⚠️ Potential issueFix test class implementation.
The test class has several issues:
- Doesn't inherit from
unittest.TestCase
- Test method doesn't follow the unittest naming convention (should start with "test_")
- Missing docstrings for class and method
Apply this diff to fix the test class:
-class DPTest: +class DPTest(unittest.TestCase): + """Test class for validating model trainer functionality.""" + def test_dp_test_1_frame(self) -> None: + """Test model predictions for a single frame with and without spin.""" trainer = get_trainer(deepcopy(self.config))Likely an incorrect or invalid review comment.
dict_to_return = {}
🛠️ Refactor suggestion
Add validation for test flags.

The flags `find_energy`, `find_force`, `find_virial`, and `find_force_mag` are retrieved from `test_data` but not validated. These flags should be integers (0 or 1).

Add validation after retrieving the flags:
dict_to_return = {}
find_energy = test_data.get("find_energy")
find_force = test_data.get("find_force")
find_virial = test_data.get("find_virial")
find_force_mag = test_data.get("find_force_mag")
+ # Validate flags
+ for flag_name, flag_value in [
+ ("find_energy", find_energy),
+ ("find_force", find_force),
+ ("find_virial", find_virial),
+ ("find_force_mag", find_force_mag),
+ ]:
+ if flag_value is not None and flag_value not in (0, 1):
+ raise ValueError(f"{flag_name} must be 0 or 1, got {flag_value}")
Also applies to: 312-315
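The proposed validation can be sketched as a standalone helper (the name `check_flag` is illustrative, not part of the deepmd-kit API):

```python
def check_flag(name, value):
    """Reject flag values other than None, 0, or 1."""
    if value is not None and value not in (0, 1):
        raise ValueError(f"{name} must be 0 or 1, got {value}")

check_flag("find_energy", 1)     # valid
check_flag("find_virial", None)  # absent flags pass through
try:
    check_flag("find_force", 2)
except ValueError as err:
    print(err)  # -> find_force must be 0 or 1, got 2
```

Failing fast here surfaces malformed test data at the point it is read, instead of producing silently wrong weighted averages later.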
dict_to_return["mae_e"] = (mae_e, energy.size)
dict_to_return["mae_ea"] = (mae_e, energy.size)
dict_to_return["mae_ea"] = (rmse_e, energy.size)
dict_to_return["mae_ea"] = (rmse_e, energy.size)
Fix duplicate dictionary keys.
The same key 'mae_ea' is used multiple times with different values, which will cause data loss.
Apply this diff to fix the duplicate keys:
- dict_to_return["mae_e"] = (mae_e, energy.size)
- dict_to_return["mae_ea"] = (mae_e, energy.size)
- dict_to_return["mae_ea"] = (rmse_e, energy.size)
- dict_to_return["mae_ea"] = (rmse_e, energy.size)
+ dict_to_return["mae_e"] = (mae_e, energy.size)
+ dict_to_return["mae_ea"] = (mae_ea, energy.size)
+ dict_to_return["rmse_e"] = (rmse_e, energy.size)
+ dict_to_return["rmse_ea"] = (rmse_ea, energy.size)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
dict_to_return["mae_e"] = (mae_e, energy.size)
dict_to_return["mae_ea"] = (mae_ea, energy.size)
dict_to_return["rmse_e"] = (rmse_e, energy.size)
dict_to_return["rmse_ea"] = (rmse_ea, energy.size)
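The silent-overwrite behavior behind this bug takes only a few lines to reproduce — assigning to an existing dict key discards the previous value without any warning:

```python
metrics = {}
metrics["mae_ea"] = (0.1, 10)  # meant to be the per-atom MAE
metrics["mae_ea"] = (0.2, 10)  # meant to be the RMSE, but overwrites instead

print(len(metrics))        # -> 1
print(metrics["mae_ea"])   # -> (0.2, 10)
```

Only the last assignment survives, so three of the four metrics are lost before the weighted average is ever computed.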
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
torch.jit.save(model, tmp_model.name)
dp_test(
🛠️ Refactor suggestion
Use context manager for temporary file.
The temporary file handling could lead to resource leaks and cleanup issues.
Apply this diff to use context managers:
- tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
- torch.jit.save(model, tmp_model.name)
- dp_test(
- model=tmp_model.name,
+ with tempfile.NamedTemporaryFile(suffix=".pth") as tmp_model:
+ torch.jit.save(model, tmp_model.name)
+ dp_test(
+ model=tmp_model.name,
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") | |
torch.jit.save(model, tmp_model.name) | |
dp_test( | |
with tempfile.NamedTemporaryFile(suffix=".pth") as tmp_model: | |
torch.jit.save(model, tmp_model.name) | |
dp_test( | |
model=tmp_model.name, | |
# ... (other parameters if present) | |
) |
🧰 Tools
🪛 Ruff (0.8.2)
35-35: Use a context manager for opening files (SIM115)
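The cleanup guarantee of the context-manager form can be checked directly with the standard library alone (no torch needed for the demonstration):

```python
import os
import tempfile

with tempfile.NamedTemporaryFile(suffix=".pth") as tmp:
    tmp.write(b"model bytes")  # stand-in for torch.jit.save
    tmp.flush()
    saved_path = tmp.name
    assert os.path.exists(saved_path)

# The file is removed automatically when the with-block exits.
print(os.path.exists(saved_path))  # -> False
```

One caveat: on Windows, an open `NamedTemporaryFile` cannot be reopened by name, so the original `delete=False` plus an explicit `os.remove` in a `finally` block may still be the portable choice for the real test.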
Summary by CodeRabbit
Refactor
Tests