feat(pt): train with energy Hessian #4169
base: devel
Conversation
Signed-off-by: Anchor Yu <[email protected]>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
deepmd/pt/train/training.py (1)
1235-1237: Consider adding error handling for Hessian parameters
While the implementation is correct, consider adding validation for Hessian-specific parameters to provide better error messages.
Example validation:

```python
if whether_hessian(loss_params):
    required_params = ['start_pref_h', 'limit_pref_h']
    missing_params = [p for p in required_params if p not in loss_params]
    if missing_params:
        raise ValueError(f"Missing required Hessian parameters: {missing_params}")
    loss_params["starter_learning_rate"] = start_lr
    return EnergyHessianStdLoss(**loss_params)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/train/training.py (5 hunks)
🔇 Additional comments (3)
deepmd/pt/train/training.py (3)
28-28: LGTM: Import added correctly
The `EnergyHessianStdLoss` import is properly placed with other loss-related imports.
1228-1231: Improve the function name and implementation
The function should be renamed to `is_hessian_required` for better clarity, and the implementation could be more concise.
1283-1287: Reconsider the architectural design of Hessian mode activation
The current implementation tightly couples loss parameters with model configuration, making it difficult to determine the model type from configuration alone.
Also applies to: 1289-1290, 1299-1300
Signed-off-by: Anchor Yu <[email protected]>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
deepmd/pt/train/training.py (2)
279-283: Consider adding documentation for the loss_params parameter
While the implementation is functional, it would be helpful to add documentation explaining why the loss parameters are used to determine the Hessian mode, especially since this might not be immediately intuitive to other developers.
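For instance, a minimal docstring sketch (the wording is hypothetical; the signature is taken from the diff below):

```python
def get_model_for_wrapper(
    _model_params,
    resuming=False,
    _loss_params=None,
):
    """Build the model to be wrapped by ModelWrapper.

    Parameters
    ----------
    _loss_params : dict, optional
        Loss configuration(s), per model key in multi-task mode. Hessian
        fitting is activated from the loss section (e.g. a nonzero
        ``start_pref_h``) rather than from the model section, so the
        builder needs the loss parameters to decide whether to call
        ``enable_hessian()`` on the constructed model.
    """
```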
1283-1287: Consider adding validation for loss parameters
While the implementation is functional, it would be beneficial to add validation for the loss parameters to ensure they contain the expected fields and values. This would help catch configuration errors early.

```diff
 def get_model_for_wrapper(
     _model_params,
     resuming=False,
     _loss_params=None,
 ):
+    def validate_loss_params(params):
+        if params is None:
+            return
+        if not isinstance(params, dict):
+            raise ValueError("Loss parameters must be a dictionary")
+        if "type" in params and params["type"] not in ["ener", "dos", "ener_spin", "denoise", "tensor", "property"]:
+            raise ValueError(f"Unsupported loss type: {params['type']}")
+
+    if _loss_params is not None:
+        if isinstance(_loss_params, dict) and "model_dict" in _model_params:
+            for key in _loss_params:
+                validate_loss_params(_loss_params[key])
+        else:
+            validate_loss_params(_loss_params)
+
     if "model_dict" not in _model_params:
```

Also applies to: 1289-1290, 1299-1300
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/pt/train/training.py (5 hunks)
🔇 Additional comments (4)
deepmd/pt/train/training.py (4)
28-28: LGTM: Import of EnergyHessianStdLoss
The import is correctly placed with other loss-related imports.
268-277: LGTM: Loss parameter handling for Hessian activation
The implementation correctly handles both single-task and multi-task scenarios by using a dictionary comprehension for multi-task loss parameters.
1228-1231: Consider renaming the function for better clarity
The function could be renamed to `is_hessian_required` for better clarity, and the logic could be simplified to a single return statement.
1235-1238: LGTM: Hessian loss handling
The implementation correctly integrates Hessian loss handling while maintaining consistency with other loss types.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
deepmd/utils/data.py (2)
666-671: Document the Hessian index mapping logic
The index mapping for the Hessian is complex and would benefit from detailed documentation explaining:
- Why chunk_size is 3 (corresponds to x,y,z coordinates)
- How the mapping preserves the relationship between atomic coordinates
- The assumption about row-major ordering of the input Hessian

```diff
 # get idx_map for hessian
 num_chunks, chunk_size = len(idx_map), 3
+# Create index mapping for 3D coordinates (x,y,z) of each atom
+# num_chunks = number of atoms, chunk_size = 3 for x,y,z coordinates
 idx_map_hess = np.arange(num_chunks * chunk_size)  # pylint: disable=no-explicit-dtype
 idx_map_hess = idx_map_hess.reshape(num_chunks, chunk_size)
+# Apply atom type sorting to maintain correspondence with atomic coordinates
 idx_map_hess = idx_map_hess[idx_map]
+# Flatten for use in both dimensions of the Hessian matrix
 idx_map_hess = idx_map_hess.flatten()
```
674-676: Clarify the Hessian size comment
The comment about the Hessian size should be more explicit about its relationship with the number of atoms.

```diff
 ndof = (
     3 * ndof * 3 * ndof
-)  # size of hessian is 3Natoms * 3Natoms
+)  # Hessian size is (3N) x (3N), where N is the number of atoms,
+# as each atom has 3 coordinates (x, y, z)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
deepmd/utils/data.py (1 hunks)
🔇 Additional comments (1)
deepmd/utils/data.py (1)
663-677: Add dimension validation for Hessian data
The Hessian reshaping logic should validate input dimensions before processing to prevent silent errors or unexpected behavior. A possible sketch follows below.
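A minimal sketch of such a check, assuming the raw array holds one row-major flattened (3N)x(3N) Hessian per frame (function and variable names are hypothetical):

```python
import numpy as np

def check_hessian_shape(hessian, nframes, natoms):
    """Validate a batch of flattened Hessians before reshaping/remapping.

    Each frame is expected to carry (3 * natoms) ** 2 entries, i.e. a
    row-major flattened (3N) x (3N) matrix.
    """
    expected = (3 * natoms) ** 2
    hessian = np.asarray(hessian).reshape(nframes, -1)
    if hessian.shape[1] != expected:
        raise ValueError(
            f"Hessian has {hessian.shape[1]} entries per frame; "
            f"expected (3 * {natoms})**2 = {expected}"
        )
    return hessian
```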
Signed-off-by: Anchor Yu <[email protected]>
I'm very curious about this Hessian DP model. Thanks!
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (3)
doc/model/train-energy-hessian.md (1)
48-54: Add validation guidance for Hessian matrices.
Consider adding guidance on how users can validate their Hessian matrices format, including:
- A script or command to verify the shape and dimensions
- Common pitfalls when preparing the data
- Example of a minimal valid Hessian matrix
Would you like me to help draft the validation guidance section?
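In the meantime, a short sketch of such a shape check (the `hessian.npy` file name mirrors the `coord.npy` convention and is an assumption here, as are the paths):

```python
import numpy as np

# Illustrative paths; adjust to your own data layout.
coords = np.load("data/set.000/coord.npy")      # shape: (nframes, 3 * natoms)
hessians = np.load("data/set.000/hessian.npy")

nframes, natoms = coords.shape[0], coords.shape[1] // 3
expected = (3 * natoms) ** 2
flat = hessians.reshape(nframes, -1)
assert flat.shape[1] == expected, (
    f"each frame must hold a flattened (3N)x(3N) Hessian, "
    f"i.e. {expected} entries, got {flat.shape[1]}"
)
print(f"OK: {nframes} frames, {natoms} atoms, Hessian block {3 * natoms}x{3 * natoms}")
```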
deepmd/pt/model/model/ener_model.py (2)
42-44: Add docstring to document the enable_hessian method
The method implementation is correct, but it would benefit from documentation explaining its purpose, effects, and any prerequisites.

```diff
 def enable_hessian(self):
+    """Enable Hessian computation for energy calculations.
+
+    This method enables the computation of energy Hessian matrices by:
+    1. Setting up the required computational graph for energy Hessian
+    2. Enabling the internal Hessian flag
+    """
     self.requires_hessian("energy")
     self._hessian_enabled = True
```
62-63: Document Hessian tensor structure and add shape validation
While the Hessian output is correctly added, it would be beneficial to document the expected tensor structure and add shape validation.

```diff
 if self._hessian_enabled:
+    # Validate Hessian shape: [batch_size, natoms * 3, natoms * 3]
+    hessian = out_def_data["energy_derv_r_derv_r"]
+    expected_shape = hessian.shape[-2:]
+    if not (len(expected_shape) == 2 and expected_shape[0] == expected_shape[1]):
+        raise ValueError(f"Invalid Hessian shape: {hessian.shape}")
     output_def["hessian"] = out_def_data["energy_derv_r_derv_r"]
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- deepmd/pt/infer/deep_eval.py (5 hunks)
- deepmd/pt/model/model/__init__.py (1 hunks)
- deepmd/pt/model/model/ener_model.py (4 hunks)
- deepmd/pt/model/model/make_hessian_model.py (1 hunks)
- deepmd/utils/argcheck.py (2 hunks)
- doc/model/train-energy-hessian.md (1 hunks)
- source/tests/pt/model/test_dp_hessian_model.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- deepmd/pt/model/model/__init__.py
- source/tests/pt/model/test_dp_hessian_model.py
- deepmd/pt/infer/deep_eval.py
- deepmd/pt/model/model/make_hessian_model.py
🧰 Additional context used
📓 Learnings (1)
deepmd/utils/argcheck.py (1)
Learnt from: 1azyking
PR: deepmodeling/deepmd-kit#4169
File: deepmd/utils/argcheck.py:1982-2117
Timestamp: 2024-11-12T05:47:21.643Z
Learning: The `loss_ener_hess` and `loss_ener` functions should remain separate to avoid confusion, despite code duplication.
🪛 Markdownlint (0.37.0)
doc/model/train-energy-hessian.md
39-39: null
Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🔇 Additional comments (10)
doc/model/train-energy-hessian.md (4)
1-5: Add standard frontmatter and TOC entry.
Please follow the project's documentation standards:
- Add frontmatter section following the format from dpa2.md
- Add this document to the table of contents
7-9: LGTM: Clear warning about model freezing limitations.
The warning effectively communicates the limitations and fallback behavior when freezing Hessian models.
18-29: Update loss type to match implementation.
The loss type should be "ener_hess" instead of "ener" when training with Hessian calculations.
39-46: Specify language for fenced code block.
To improve consistency and adhere to Markdown best practices, specify a language for the fenced code block.
🧰 Tools
🪛 Markdownlint (0.37.0)
39-39: null
Fenced code blocks should have a language specified
(MD040, fenced-code-language)
deepmd/utils/argcheck.py (2)
2179-2180: LGTM! Documentation is clear and consistent.
The documentation for the new Hessian parameters follows the project's style and clearly explains their purpose.
2235-2248: Implementation looks good, verify loss calculation handling.
The new Hessian parameters are well-implemented with appropriate types, defaults, and documentation. The implementation follows the project's patterns consistently.
Run the following script to verify the loss calculation code handles these new parameters:
✅ Verification successful
Implementation is correctly handled in loss calculation code
The verification shows that the new Hessian parameters are properly integrated into the loss calculation system:
- The parameters are correctly used in `deepmd/pt/loss/ener.py`, where:
  - They control the Hessian loss activation via `has_h = (start_pref_h != 0.0 and limit_pref_h != 0.0)`
  - The loss calculation properly scales between start and limit prefactors (see the sketch after the scripts below)
  - The Hessian loss is correctly computed and added to the total loss
- The implementation is tested in `source/tests/pt/test_loss.py`
- Example configurations in `examples/hessian/` demonstrate proper usage
- The training system correctly enables Hessian mode when these parameters are set via the `whether_hessian()` function
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Description: Verify that the loss calculation code properly handles the new Hessian parameters

# Check for loss calculation implementation
rg -A 10 "start_pref_h|limit_pref_h" --type python

# Check for test coverage
rg -A 10 "test.*hessian.*loss" --type python
```
Length of output: 161
Script:
```bash
#!/bin/bash
# Let's try without file type filtering since it's not supported

# Check for loss calculation implementation
rg -A 10 "start_pref_h|limit_pref_h"

# Check for test coverage
rg -A 10 "test.*hessian.*loss"

# Additionally search for Hessian-related loss calculations
rg -A 10 "hessian.*loss|loss.*hessian"
```
Length of output: 16280
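As a side note on the prefactor scaling mentioned in the verification above, here is a minimal sketch of the usual DeePMD-style schedule (an assumption about the form, not a quote from `ener.py`): the Hessian prefactor is interpolated between `start_pref_h` and `limit_pref_h` as the learning rate decays.

```python
def scaled_pref(start_pref, limit_pref, cur_lr, start_lr):
    """Interpolate a loss prefactor with the learning-rate decay.

    Equals start_pref at the initial learning rate and tends to
    limit_pref as the learning rate decays toward zero.
    """
    return limit_pref + (start_pref - limit_pref) * cur_lr / start_lr

# e.g. start_pref_h=10.0, limit_pref_h=1.0: the Hessian weight shrinks
# from 10.0 toward 1.0 as training progresses and the LR decays.
pref_h = scaled_pref(10.0, 1.0, cur_lr=1e-4, start_lr=1e-3)
```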
deepmd/pt/model/model/ener_model.py (4)
18-20: LGTM: Import follows project conventions
The new import for Hessian model functionality follows the project's import conventions and is properly placed with other model-related imports.
40-40: LGTM: Proper initialization of Hessian flag
The `_hessian_enabled` flag is properly initialized as a private attribute with an appropriate default value of False, following Python conventions.
26-26: Verify the Hessian model wrapper implementation
The Hessian wrapper is correctly applied after the base model creation. Let's verify the wrapper implementation:
✅ Verification successful
The previous ast-grep patterns didn't work because of syntax issues. Let's try with ripgrep to find the implementation and usage patterns, along with checking the actual file content to understand the context better.
Based on the search results, I can now generate the final response:
Hessian model wrapper is correctly implemented and follows established patterns
The verification shows that:
- The Hessian wrapper is properly applied after the base model creation (`DPEnergyModel_ = make_hessian_model(DPEnergyModel_)`)
- The wrapper follows the project's model decoration pattern, similar to other model wrappers like `make_model`
- The implementation includes proper Hessian functionality handling through `enable_hessian()` and output translation in `translated_output_def()`
- Test coverage exists in `test_make_hessian_model.py`, validating the wrapper's functionality
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Check the make_hessian_model implementation and its usage pattern
ast-grep --pattern 'def make_hessian_model($_) {
  $$$
}'

# Check for similar wrapper patterns in the codebase
ast-grep --pattern '$_ = make_$_($_)'
```
Length of output: 99
Script:
```bash
#!/bin/bash
# Search for make_hessian_model definition and usage with ripgrep
rg "make_hessian_model" -A 5

# Search for similar model wrapper patterns
rg "make_.+_model" -A 5

# Look for the actual file content
cat deepmd/pt/model/model/ener_model.py
```
Length of output: 23253
99-100: Consider adding Hessian support to forward_lower method
While Hessian handling is correctly implemented in the forward pass, consider adding corresponding support in the `forward_lower` method for consistency.
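A minimal sketch of what that could look like, assuming `forward_lower` assembles an output dict the same way `forward` does (the dict and key names below mirror the forward-pass guard shown earlier in this review and are otherwise hypothetical):

```python
# Inside forward_lower, after the base model outputs are assembled:
if self._hessian_enabled:
    # Mirror the forward-pass guard: expose the second derivative of the
    # energy w.r.t. coordinates under the translated "hessian" key.
    model_predict["hessian"] = model_ret["energy_derv_r_derv_r"]
```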
Thank you for your comment! Here's a response to your questions:
If you have any additional questions, feel free to leave another comment!
Signed-off-by: Anchor Yu <[email protected]>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pt/model/model/ener_model.py
(4 hunks)
🔇 Additional comments (4)
deepmd/pt/model/model/ener_model.py (4)
18-20: LGTM: Clean import addition
The new import for `make_hessian_model` is properly organized with other model-related imports.
39-39: LGTM: Clean initialization of Hessian flag
The `_hessian_enabled` flag is properly initialized with a sensible default value.
63-64: LGTM: Clean addition of Hessian output definition
The Hessian output is properly guarded by the enabled flag and follows the existing pattern.
100-101: Verify Hessian support in forward_lower method
While the Hessian support is properly added to the forward method, please verify if similar changes are needed in the `forward_lower` method for consistency.
Let's check if there are any references to Hessian computation in lower-level implementations:
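A search along these lines would do (a sketch; the patterns are illustrative):

```bash
#!/bin/bash
# Look for Hessian-related keys in the lower-level model interfaces
rg -n "energy_derv_r_derv_r|hessian" deepmd/pt/model/model/

# Check whether forward_lower implementations touch the Hessian output
rg -n -A 5 "def forward_lower" deepmd/pt/model/model/
```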
Signed-off-by: Anchor Yu <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (2)
source/tests/pt/test_change_bias.py (2)
126-129: Consider refactoring common verification logic
The verification logic is duplicated across test methods. Consider extracting the common model initialization and verification code into a helper method.

```diff
+    def _verify_bias_change(self, state_dict, expected_bias):
+        model_params = state_dict["model"]["_extra_state"]["model_params"]
+        model_for_wrapper = get_model_for_wrapper(
+            model_params,
+            _loss_params=self.loss_params,
+        )
+        wrapper = ModelWrapper(model_for_wrapper)
+        wrapper.load_state_dict(state_dict["model"])
+        updated_bias = wrapper.model["Default"].get_out_bias()
+        torch.testing.assert_close(updated_bias, expected_bias)
```
150-153: Consider parameterizing the test with more test cases
The test could be enhanced by using `@pytest.mark.parametrize` to test various user-defined bias values, including edge cases and invalid inputs.
Example parameterization:

```python
@pytest.mark.parametrize("user_bias", [
    [0.1, 3.2, -0.5],  # current case
    [0.0, 0.0, 0.0],   # zero bias
    [-1e6, 1e6, 0],    # extreme values
])
def test_change_bias_with_user_defined(self, user_bias):
    # existing test code
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/tests/pt/test_change_bias.py (4 hunks)
🔇 Additional comments (2)
source/tests/pt/test_change_bias.py (2)
90-90: LGTM: Loss parameters initialization in setUp
The initialization of loss parameters from the config is correctly placed in the setUp method, ensuring fresh configuration for each test case.
100-103: Verify test compatibility with Hessian mode
The test now correctly passes loss parameters to the model wrapper. However, we should verify that this test works correctly when Hessian mode is enabled.
It's better to add tutorials with commands for the following situations (a sketch of such commands follows below):
- users want to train a Hessian model from scratch and freeze an energy model.
- users want to train a Hessian model from a pretrained energy model and freeze an energy model.
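For example, along these lines (a hedged sketch: file names are placeholders, and the exact flags should be checked against the CLI docs):

```bash
# Train a Hessian model from scratch, then freeze; per the docs' warning,
# freezing falls back to a plain energy model.
dp --pt train input_hessian.json
dp --pt freeze -o frozen_energy_model.pth

# Train a Hessian model starting from a pretrained energy model,
# then freeze an energy model from it.
dp --pt train input_hessian.json --finetune pretrained_energy_model.pt
dp --pt freeze -o frozen_energy_model.pth
```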
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)
1228-1231: Consider improving function name and simplifying logic
While the function logic is correct, consider these improvements for better readability:
- Rename to `is_hessian_required` or `should_use_hessian` to better reflect its purpose
- Simplify to a single return statement

```diff
-def whether_hessian(loss_params):
-    loss_type = loss_params.get("type", "ener")
-    return loss_type == "ener" and loss_params.get("start_pref_h", 0.0) > 0.0
+def is_hessian_required(loss_params):
+    return loss_params.get("type", "ener") == "ener" and loss_params.get("start_pref_h", 0.0) > 0.0
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/train/training.py (5 hunks)
- deepmd/utils/argcheck.py (2 hunks)
🧰 Additional context used
📓 Learnings (1)
deepmd/utils/argcheck.py (1)
Learnt from: 1azyking
PR: deepmodeling/deepmd-kit#4169
File: deepmd/utils/argcheck.py:1982-2117
Timestamp: 2024-11-12T05:47:21.643Z
Learning: The `loss_ener_hess` and `loss_ener` functions should remain separate to avoid confusion, despite code duplication.
🔇 Additional comments (4)
deepmd/utils/argcheck.py (2)
2179-2180: LGTM! Documentation strings are clear and consistent.
The documentation strings for the new Hessian prefactor parameters follow the established pattern and clearly explain their purpose.
2235-2248: LGTM! Implementation of Hessian prefactor parameters is correct.
The implementation:
- Uses consistent types ([float, int]) with other prefactor parameters
- Sets appropriate default values (0.0) to make Hessian loss opt-in
- Follows the same argument structure as other loss prefactors
deepmd/pt/train/training.py (2)
28-28: LGTM!
The import of `EnergyHessianStdLoss` is correctly placed with other loss imports.
1235-1238: LGTM!
The Hessian loss integration follows the established pattern and correctly:
- Uses `whether_hessian` for conditional activation
- Sets required parameters
- Returns appropriate loss type