-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathAPI_backend.py
More file actions
151 lines (121 loc) · 5.18 KB
/
API_backend.py
File metadata and controls
151 lines (121 loc) · 5.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
from flask import Flask, jsonify
from flask_cors import CORS
from flask_restful import Api, Resource, reqparse
import pickle
import pandas as pd
import numpy as np
import json
from transformers import AutoModelForTokenClassification, AutoTokenizer
def _merge_wordpiece_tokens(decoded_words, class_numbers):
    """Merge per-token predictions into word-level entries.

    Tokens are walked in order ('[CLS]' and '[SEP]' are skipped) and glued
    into words:

    * a token containing '#' (a WordPiece continuation such as '##ing') has
      every '#' removed and is appended to the current word; the current
      word keeps its class;
    * a token whose class equals the current word's class is appended when
      it is '.' or directly follows a '.' token, so values like "12.5"
      stay together;
    * any other token closes the current word and starts a new one.

    NOTE(review): a literal '#' in the message is also treated as a
    continuation, so hashtags lose their '#'. Confirm this is intended.

    Args:
        decoded_words: one decoded string per token.
        class_numbers: predicted class index per token, same order/length.

    Returns:
        ``(words, classes)``: parallel lists. The buffer starts as ``''``
        with class 0 and is flushed when the first new word starts, so the
        first entry is normally ``''`` / 0 (kept for output compatibility).
    """
    words = []
    classes = []
    current_word = ''
    current_class = 0
    previous_token = ''
    for token, class_number in zip(decoded_words, class_numbers):
        if token in ('[CLS]', '[SEP]'):
            continue
        if '#' in token:
            # Continuation piece: extend the current word, keep its class.
            current_word += token.replace('#', '')
        elif class_number == current_class and (token == '.' or previous_token == '.'):
            # Keep decimals such as "12.5" together when all parts share a class.
            current_word += token
        else:
            words.append(current_word)
            classes.append(current_class)
            current_word = token
            current_class = class_number
        previous_token = token
    words.append(current_word)
    classes.append(current_class)
    return words, classes


def get_class_map_from_message_NEW(input_message: str) -> dict:
    """Run the token-classification model on a message and label each word.

    Relies on the module-level ``model`` and ``tokenizer`` globals. Those
    are only created in this file's ``__main__`` block.

    Args:
        input_message: the message to classify; any value is coerced with
            ``str()`` first.

    Returns:
        A dict with keys ``'sentence'``, ``'sentence_class'`` and
        ``'sentence_class_name'``. Each value is the ``str()`` of a list:
        the merged words, their integer class indices, and the
        corresponding label names.

    Raises:
        KeyError: if the model predicts a class index missing from the
            label map below.
    """
    import torch  # return_tensors='pt' already requires torch to be installed

    # NOTE(review): 4 and 11 both map to 'momentum'. Confirm against the
    # model's label config.
    class_number_to_name_dict = {0: '',
                                 1: 'btst',
                                 2: 'delivery',
                                 3: 'enter',
                                 4: 'momentum',
                                 5: 'exit',
                                 6: 'exit2',
                                 7: 'exit3',
                                 8: 'intraday',
                                 9: 'sl',
                                 10: 'symbol',
                                 11: 'momentum'}

    ########### PREDICT TEXT AND CLASSIFY WORDS ##########
    input_message = str(input_message)
    # truncation=True caps input at the tokenizer's model_max_length, so
    # over-long messages do not crash the model.
    inputs = tokenizer(input_message, return_tensors='pt', truncation=True)
    # Pure inference: skip building an autograd graph on every request.
    with torch.no_grad():
        outputs = model(**inputs)

    decoded_words = tokenizer.batch_decode(inputs['input_ids'][0])
    # int() matters: under NumPy >= 2, str() of a list holding np.int64 values
    # renders "np.int64(3)" instead of "3" and corrupts 'sentence_class'.
    class_numbers = [int(np.argmax(row)) for row in outputs.logits[0].tolist()]

    sentence, sentence_class = _merge_wordpiece_tokens(decoded_words, class_numbers)
    sentence_class_name = [class_number_to_name_dict[number] for number in sentence_class]

    return {'sentence': str(sentence),
            'sentence_class': str(sentence_class),
            'sentence_class_name': str(sentence_class_name),
            }
# Flask application with CORS enabled, wrapped in a flask_restful Api on
# which the Resource classes in this file are registered.
# NOTE(review): CORS(app) with no options presumably allows any origin
# (flask_cors default). Confirm that is intended for this deployment.
app = Flask(__name__)
CORS(app)
api = Api(app)
# Health check endpoint
class HealthCheck(Resource):
    """Report whether the service is up and the NER model is available."""

    def get(self):
        """Return the service health as JSON.

        ``model`` and ``tokenizer`` are module globals created only when
        this file runs as ``__main__``. If the app is served another way
        (e.g. imported by a WSGI server), they do not exist and every
        classification request would fail. Respond 503 in that case
        instead of claiming the model is loaded.
        """
        module_globals = globals()
        if module_globals.get('model') is None or module_globals.get('tokenizer') is None:
            response = jsonify({'status': 'unhealthy', 'model': 'not loaded'})
            response.status_code = 503
            return response
        return jsonify({'status': 'healthy', 'model': 'loaded'})


api.add_resource(HealthCheck, '/health')
# Create parser for the payload data
parser = reqparse.RequestParser()
# 'data' carries the JSON-encoded message. Marking it required makes
# flask_restful answer 400 with this help text, instead of the handler
# failing on json.loads(None) with a 500.
parser.add_argument('data', required=True,
                    help='JSON-encoded message text to classify')
def convert(o):
    """``json.dumps(default=...)`` hook that turns NumPy values into Python ones.

    Args:
        o: an object the standard ``json`` encoder could not serialize.

    Returns:
        The equivalent native Python value: ``.item()`` for NumPy scalars,
        ``.tolist()`` for NumPy arrays.

    Raises:
        TypeError: for any other type, with the same message shape the
            ``json`` module uses, as its ``default`` hook contract expects.
    """
    if isinstance(o, np.generic):
        return o.item()
    if isinstance(o, np.ndarray):
        return o.tolist()
    raise TypeError(f'Object of type {type(o).__name__} is not JSON serializable')
# Define how the api will respond to the post requests
class MessageNER(Resource):
    """POST endpoint that runs NER over a single message."""

    def post(self):
        """Classify the message sent in the ``data`` argument.

        ``data`` must be a JSON-encoded value, e.g. ``"buy XYZ above 100"``
        including the quotes. The decoded value is passed as-is to
        ``get_class_map_from_message_NEW``, which stringifies it.

        Returns:
            The dict from ``get_class_map_from_message_NEW`` on success, or
            ``({'message': ...}, 400)`` when ``data`` is missing or is not
            valid JSON.
        """
        args = parser.parse_args()
        raw_data = args['data']
        if raw_data is None:
            return {'message': "Missing required argument 'data'."}, 400
        try:
            message = json.loads(raw_data)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return {'message': "Argument 'data' must be valid JSON."}, 400
        return get_class_map_from_message_NEW(message)


api.add_resource(MessageNER, '/classifyner')
if __name__ == '__main__':
    from dotenv import load_dotenv
    load_dotenv()

    ###### LOAD PRETRAINED MODEL FROM HUGGINGFACE autoTrain #################
    # NOTE: `model` and `tokenizer` become module globals read by
    # get_class_map_from_message_NEW. They exist only when this file is run
    # directly, not when it is imported (e.g. by a WSGI server).
    model_id = os.getenv('HUGGINGFACE_MODEL_ID', 'hemangjoshi37a/autotrain-ratnakar_1000_sample_curated-1474454086')
    hf_token = os.getenv('HUGGINGFACE_TOKEN')
    model_kwargs = {}
    if hf_token:
        model_kwargs['token'] = hf_token
    model = AutoModelForTokenClassification.from_pretrained(model_id, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_id, **model_kwargs)

    port = int(os.getenv('FLASK_PORT', 3737))
    # Debug mode is opt-in for two reasons:
    # - it enables the Werkzeug interactive debugger, which Flask warns must
    #   never be exposed;
    # - its reloader re-runs this script in a child process, loading the
    #   model a second time.
    debug = os.getenv('FLASK_DEBUG', 'False').strip().lower() in ('1', 'true', 'yes')
    app.run(debug=debug, port=port)