Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add index_to_name.json #6

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 36 additions & 50 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,26 @@ Read more about this pre-trained model [here.](https://towardsdatascience.com/nl

**In collaboration with [Charan Pothireddi](https://www.linkedin.com/in/sree-charan-pothireddi-6a0a3587/) and [Parabole.ai](https://www.linkedin.com/in/sree-charan-pothireddi-6a0a3587/)**

## Prerequisites
The further pre-trained ESG-BERT model can be found [here](https://drive.google.com/drive/folders/1yfNpMvByz3fJMsOqir3SerS6PwsRS2rt?usp=sharing) at this GitHub repository. It is a PyTorch model but it can be converted into a Tensorflow model. They can be fine-tuned using either framework. I found the PyTorch framework to be a lot cleaner, and easier to replicate with other models. However, serving the final fine-tuned model is a lot easier on TensorFlow, than on PyTorch. 

You can download the ESG-BERT model (named pytorch_model.bin) along with config.json and vovab.txt fles here. BERT base model was further pre-trained on Sustainable Investing text corpus, resulting in a domain specific model. You need the all of those 3 files for fine-tuning.
You can download the ESG-BERT model (named `pytorch_model.bin`) along with `config.json` and `vocab.txt` files here. The BERT base model was further pre-trained on Sustainable Investing text corpus, resulting in a domain specific model. You need the all of those 3 files for fine-tuning.

For fine-tuning the model, you can use this command to load it into PyTorch. 
```
model = BertForSequenceClassification.from_pretrained(
'path/to/dir/containing/ESG-BERT',
num_labels = num, #number of classifications
output_attentions = False, # Whether the model returns attentions weights.
output_hidden_states = False, # Whether the model returns all hidden-states.
)
model.to(device)

```
The fine-tuned model for text classification is also available [here](https://drive.google.com/drive/folders/1Qz4HP3xkjLfJ6DGCFNeJ7GmcPq65_HVe?usp=sharing). It can be used directly to make predictions using just a few steps. 
First, download the fine-tuned pytorch_model.bin, config.json, and vocab.txt into your local directory. Make sure to place all of them into the same directory, mine is called "bert_model". 
First, download the fine-tuned `pytorch_model.bin`, `config.json`, and `vocab.txt` into your local directory. Make sure to place all of them into the same directory, mine is called `bert_model`.

### Install dependencies
JDK 11 is needed to serve the model. Go ahead and install it from the Oracle downloads page. Now we are ready to set up TorcheServe.
TorchServe is a model serving architecture for PyTorch models, go ahead and install that using pip. You can also use conda for the installation. We also need pytorch and transformers installed.
```
``` bash
pip install torchserve torch-model-archiver
pip install torchvision
pip install transformers
```
Next up, we'll set up the handler script. It is a basic handler for text classification that can be improved upon. Save this script as "handler.py" in your directory. [1]
```

### Set up the handler script
Next up, we'll set up the handler script. It is a basic handler for text classification that can be improved upon. Save this script as `handler.py` in your directory. [1]
``` python
from abc import ABC
import json
import logging
Expand Down Expand Up @@ -115,52 +109,44 @@ return data
raise e

```
TorcheServe uses a format called MAR (Model Archive). We can convert our PyTorch model to a .mar file using this command:
```
torch-model-archiver --model-name "bert" --version 1.0 --serialized-file ./bert_model/pytorch_model.bin --extra-files "./bert_model/config.json,./bert_model/vocab.txt" --handler "./handler.py"
## Creating a torchserve model archive
Create a new model directory:
``` bash
mkdir model_store
```
Move the .mar file into a new directory: 
```
mkdir model_store && mv bert.mar model_store
TorcheServe uses a format called `MAR` (Model Archive). We can convert our PyTorch model to a `.mar` file using this command:
``` bash
torch-model-archiver --model-name "bert" --version 1.0 --serialized-file ./bert_model/pytorch_model.bin --extra-files "./bert_model/config.json,./bert_model/vocab.txt,./bert_model/index_to_name.json" --handler "./handler.py" --export-path "model_store/"
```
The resulting mar file will be stored in the `model_store` directory we just created.

## Serve the model
Finally, we can start TorchServe using the command: 
```
torchserve --start --model-store model_store --models bert=bert.mar
```

## Test the model
We can now query the model from another terminal window using the Inference API. We pass a text file containing text that the model will try to classify. 

```
curl -X POST http://127.0.0.1:8080/predictions/bert -T predict.txt
```
This returns a label number which correlates to a textual label. This is stored in the label_dict.txt dictionary file. 
```
__label__Business_Ethics : 0
__label__Data_Security : 1
__label__Access_And_Affordability : 2
__label__Business_Model_Resilience : 3
__label__Competitive_Behavior : 4
__label__Critical_Incident_Risk_Management : 5
__label__Customer_Welfare : 6
__label__Director_Removal : 7
__label__Employee_Engagement_Inclusion_And_Diversity : 8
__label__Employee_Health_And_Safety : 9
__label__Human_Rights_And_Community_Relations : 10
__label__Labor_Practices : 11
__label__Management_Of_Legal_And_Regulatory_Framework : 12
__label__Physical_Impacts_Of_Climate_Change : 13
__label__Product_Quality_And_Safety : 14
__label__Product_Design_And_Lifecycle_Management : 15
__label__Selling_Practices_And_Product_Labeling : 16
__label__Supply_Chain_Management : 17
__label__Systemic_Risk_Management : 18
__label__Waste_And_Hazardous_Materials_Management : 19
__label__Water_And_Wastewater_Management : 20
__label__Air_Quality : 21
__label__Customer_Privacy : 22
__label__Ecological_Impacts : 23
__label__Energy_Management : 24
__label__GHG_Emissions : 25
This returns a textual label, defined in the `index_to_name.json` file.

## Fine-tuning the model yourself
For fine-tuning the model, you can use this command to load it into PyTorch. 
``` python
model = BertForSequenceClassification.from_pretrained(
'path/to/dir/containing/ESG-BERT',
num_labels = num, #number of classifications
output_attentions = False, # Whether the model returns attentions weights.
output_hidden_states = False, # Whether the model returns all hidden-states.
)
model.to(device)
```


References:
[1] - ---

Expand Down
28 changes: 28 additions & 0 deletions index_to_name.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"0":"Business Ethics",
"1":"Data Security",
"2":"Access and Affordability",
"3":"Business Model Resilience",
"4":"Competitive Behavior",
"5":"Critical Incident Risk Management",
"6":"Customer Welfare",
"7":"Director Removal",
"8":"Employee Engagement Inclusion And Diversity",
"9":"Employee Health And Safety",
"10":"Human Rights And Community Relations",
"11":"Labor Practices",
"12":"Management Of Legal And Regulatory Framework",
"13":"Physical Impacts Of Climate Change",
"14":"Product Quality And Safety",
"15":"Product Design And Lifecycle Management",
"16":"Selling Practices And Product Labeling",
"17":"Supply Chain Management",
"18":"Systemic Risk Management",
"19":"Waste And Hazardous Materials Management",
"20":"Water And Wastewater Management",
"21":"Air Quality",
"22":"Customer Privacy",
"23":"Ecological Impacts",
"24":"Energy Management",
"25":"GHG Emissions"
}