Skip to content

Commit 8c99e4b

Browse files
savedmodel generation
1 parent 1370df7 commit 8c99e4b

31 files changed

+7943
-186
lines changed

modelzoo/BST/pb_to_pbtxt.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Dump the BST SavedModel protobuf as a human-readable pbtxt file."""
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging

# Directory of the exported SavedModel (must contain saved_model.pb).
source_dir = "/home/deeprec/DeepRec/modelzoo/BST/savedmodels/1"

logging.info("before _parse_saved_model.")
# NOTE(review): loader_impl._parse_saved_model is a private TF API and may
# change between TensorFlow versions.
saved_model = loader_impl._parse_saved_model(source_dir)
logging.info("_parse_saved_model done.")

# Write the text-format proto next to the binary one, as saved_model.pbtxt.
pbtxt_path = source_dir + "/saved_model.pb" + "txt"
file_io.write_string_to_file(pbtxt_path, str(saved_model))

modelzoo/BST/prepare_savedmodel.py

Lines changed: 738 additions & 0 deletions
Large diffs are not rendered by default.

modelzoo/BST/result/README.md

Lines changed: 0 additions & 2 deletions
This file was deleted.

modelzoo/BST/start_serving.cc

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#include <iostream>
#include <cstdlib>   // malloc / free
#include <cstring>   // strlen
#include <string>
#include "serving/processor/serving/processor.h"
#include "serving/processor/serving/predict.pb.h"

// JSON configuration handed to the DeepRec serving processor.
// NOTE(review): checkpoint_dir / savedmodel_dir are hard-coded absolute
// paths — confirm they match the deployment layout.
static const char* model_config = "{ \
\"omp_num_threads\": 4, \
\"kmp_blocktime\": 0, \
\"feature_store_type\": \"memory\", \
\"serialize_protocol\": \"protobuf\", \
\"inter_op_parallelism_threads\": 10, \
\"intra_op_parallelism_threads\": 10, \
\"init_timeout_minutes\": 1, \
\"signature_name\": \"serving_default\", \
\"read_thread_num\": 3, \
\"update_thread_num\": 2, \
\"model_store_type\": \"local\", \
\"checkpoint_dir\": \"/root/deeprec/DeepRec/modelzoo/BST/result/\", \
\"savedmodel_dir\": \"/root/deeprec/DeepRec/modelzoo/BST/savedmodels/1657183908.2336085/\" \
} ";

// Input features of the BST model, in request order.
// BUGFIX: this used to be a stray Python list (`INPUT_FEATURES = [...]`)
// pasted into the .cc file, which does not compile as C++:
//   pid, adgroup_id, cate_id, campaign_id, customer, brand,
//   user_id, cms_segid, cms_group_id, final_gender_code, age_level,
//   pvalue_level, shopping_level, occupation, new_user_class_level,
//   tag_category_list, tag_brand_list, price

// One sample request; every feature is fed to the model as a string.
// BUGFIX: fields were declared as unqualified `string`, which neither
// compiles (no std:: / <string>) nor works with the strlen()/char* usage
// below; const char* matches both the literals and get_proto().
struct input_format {
  const char* pid;
  const char* adgroup_id;
  const char* cate_id;
  const char* campaign_id;
  const char* customer;
  const char* brand;
  const char* user_id;
  const char* cms_segid;
  const char* cms_group_id;
  const char* final_gender_code;
  const char* age_level;
  const char* pvalue_level;
  const char* shopping_level;
  const char* occupation;
  const char* new_user_class_level;
  const char* tag_category_list;
  const char* tag_brand_list;
  const char* price;
};

// Wrap a single C string into a 1 x dim ArrayProto of the given dtype.
// NOTE(review): dim is the string length here (mirroring the original
// call sites) — confirm the model really expects shape [1, strlen] for
// scalar string features rather than [1, 1].
::tensorflow::eas::ArrayProto get_proto(const char* char_input, int dim,
                                        ::tensorflow::eas::ArrayDataType type) {
  ::tensorflow::eas::ArrayShape array_shape;
  array_shape.add_dim(1);
  array_shape.add_dim(dim);

  ::tensorflow::eas::ArrayProto input;
  input.add_string_val(char_input);
  input.set_dtype(type);
  *(input.mutable_array_shape()) = array_shape;
  return input;
}

int main(int argc, char** argv) {
  int state;
  void* model = initialize("", model_config, &state);
  if (state == -1) {
    std::cerr << "initialize error\n";
    // BUGFIX: previously fell through and called process() on a model that
    // failed to initialize.
    return -1;
  }

  // Sample input row (BUGFIX: the aggregate initializer was missing its ';').
  input_format inputs = {"430548_1007","669310","1665","360359","167792","247789","841908","81","10","1","4","2","3","0","3","8153|8153|8153|8154|8154|8154|1673|1673|1673|6115|6115|6115|1665|1665|1665|1665|1665|1665|8188|8188|8188|8188|8188|8188|8188|8188|8188|1665|1665|1665|8188|8188|8188|8188|8188|8188|8188|8188|8188|10747|10747|10747|10747|10747|10747|10747|10747|10747|10747|10747","197848|197848|197848|237004|237004|237004|330898|330898|330898|337445|337445|337445|258262|258262|258262|247789|247789|247789|339517|339517|339517|339517|339517|339517|339517|339517|339517|278878|278878|278878|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517|339517","6"};

  // All features are serialized as DT_STRING tensors.
  ::tensorflow::eas::ArrayDataType dtype =
      ::tensorflow::eas::ArrayDataType::DT_STRING;

  // ---------------------------- input setting ----------------------------
  ::tensorflow::eas::ArrayProto input0 = get_proto(inputs.pid, strlen(inputs.pid), dtype);
  ::tensorflow::eas::ArrayProto input1 = get_proto(inputs.adgroup_id, strlen(inputs.adgroup_id), dtype);
  ::tensorflow::eas::ArrayProto input2 = get_proto(inputs.cate_id, strlen(inputs.cate_id), dtype);
  ::tensorflow::eas::ArrayProto input3 = get_proto(inputs.campaign_id, strlen(inputs.campaign_id), dtype);
  ::tensorflow::eas::ArrayProto input4 = get_proto(inputs.customer, strlen(inputs.customer), dtype);
  ::tensorflow::eas::ArrayProto input5 = get_proto(inputs.brand, strlen(inputs.brand), dtype);
  ::tensorflow::eas::ArrayProto input6 = get_proto(inputs.user_id, strlen(inputs.user_id), dtype);
  ::tensorflow::eas::ArrayProto input7 = get_proto(inputs.cms_segid, strlen(inputs.cms_segid), dtype);
  ::tensorflow::eas::ArrayProto input8 = get_proto(inputs.cms_group_id, strlen(inputs.cms_group_id), dtype);
  ::tensorflow::eas::ArrayProto input9 = get_proto(inputs.final_gender_code, strlen(inputs.final_gender_code), dtype);
  ::tensorflow::eas::ArrayProto input10 = get_proto(inputs.age_level, strlen(inputs.age_level), dtype);
  ::tensorflow::eas::ArrayProto input11 = get_proto(inputs.pvalue_level, strlen(inputs.pvalue_level), dtype);
  ::tensorflow::eas::ArrayProto input12 = get_proto(inputs.shopping_level, strlen(inputs.shopping_level), dtype);
  ::tensorflow::eas::ArrayProto input13 = get_proto(inputs.occupation, strlen(inputs.occupation), dtype);
  ::tensorflow::eas::ArrayProto input14 = get_proto(inputs.new_user_class_level, strlen(inputs.new_user_class_level), dtype);
  ::tensorflow::eas::ArrayProto input15 = get_proto(inputs.tag_category_list, strlen(inputs.tag_category_list), dtype);
  ::tensorflow::eas::ArrayProto input16 = get_proto(inputs.tag_brand_list, strlen(inputs.tag_brand_list), dtype);
  ::tensorflow::eas::ArrayProto input17 = get_proto(inputs.price, strlen(inputs.price), dtype);

  // Assemble the PredictRequest: signature + output filter + named inputs.
  ::tensorflow::eas::PredictRequest req;
  req.set_signature_name("serving_default");
  req.add_output_filter("output:0");

  (*req.mutable_inputs())["pid:0"] = input0;
  (*req.mutable_inputs())["adgroup_id:0"] = input1;
  (*req.mutable_inputs())["cate_id:0"] = input2;
  (*req.mutable_inputs())["campaign_id:0"] = input3;
  (*req.mutable_inputs())["customer:0"] = input4;
  (*req.mutable_inputs())["brand:0"] = input5;
  (*req.mutable_inputs())["user_id:0"] = input6;
  (*req.mutable_inputs())["cms_segid:0"] = input7;
  (*req.mutable_inputs())["cms_group_id:0"] = input8;
  (*req.mutable_inputs())["final_gender_code:0"] = input9;
  (*req.mutable_inputs())["age_level:0"] = input10;
  (*req.mutable_inputs())["pvalue_level:0"] = input11;
  (*req.mutable_inputs())["shopping_level:0"] = input12;
  (*req.mutable_inputs())["occupation:0"] = input13;
  (*req.mutable_inputs())["new_user_class_level:0"] = input14;
  (*req.mutable_inputs())["tag_category_list:0"] = input15;
  (*req.mutable_inputs())["tag_brand_list:0"] = input16;
  (*req.mutable_inputs())["price:0"] = input17;

  // Serialize the request into a raw buffer for process().
  size_t size = req.ByteSizeLong();
  void* buffer = malloc(size);
  req.SerializeToArray(buffer, size);

  // Run inference.
  void* output = nullptr;
  int output_size = 0;
  state = process(model, buffer, size, &output, &output_size);
  free(buffer);  // BUGFIX: request buffer was leaked.
  // NOTE(review): ownership of `output` is defined by the processor API —
  // confirm whether the caller must free it.

  // Parse and print the response.
  std::string output_string((char*)output, output_size);
  ::tensorflow::eas::PredictResponse resp;
  resp.ParseFromString(output_string);
  std::cout << "process returned state: " << state << ", response: " << resp.DebugString();

  return 0;
}
143+

modelzoo/BST/train.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import collections
88
from tensorflow.python.client import timeline
99
import json
10+
from glob import glob
1011

1112
from tensorflow.python.ops import partitioned_variables
1213

@@ -55,6 +56,10 @@
5556
'price': 50
5657
}
5758

59+
# next_element = iter.next()
60+
61+
# dict("pid":"1 2 3 4 5 ... 1000",)
62+
5863

5964
class BST():
6065
def __init__(self,
@@ -78,9 +83,10 @@ def __init__(self,
7883
if not inputs:
7984
raise ValueError("Dataset is not defined.")
8085
self._feature = inputs[0]
81-
self._label = inputs[1]
82-
86+
self._label = inputs[1]
87+
8388
self._unseq_column = user_column + item_column
89+
8490
self._tag_column = tag_column
8591
self._key_column = key_column
8692
self._batch_size = batch_size
@@ -101,7 +107,7 @@ def __init__(self,
101107
self._input_layer_partitioner = input_layer_partitioner
102108
self._dense_layer_partitioner = dense_layer_partitioner
103109

104-
self._create_model()
110+
self.r = self._create_model()
105111
with tf.name_scope('head'):
106112
self._create_loss()
107113
self._create_optimizer()
@@ -262,6 +268,7 @@ def _create_model(self):
262268
self._feature,
263269
self._unseq_column,
264270
cols_to_output_tensors=key_dict)
271+
265272

266273
# bst input
267274
with tf.variable_scope('bst_input_layer', reuse=tf.AUTO_REUSE):
@@ -298,6 +305,7 @@ def _create_model(self):
298305
seq_size=self._max_seqence_length,
299306
head_count=self._multi_head_size,
300307
name='bst')
308+
301309

302310
net = tf.concat([unseq_emb, bst_output], axis=1)
303311

@@ -323,8 +331,14 @@ def _create_model(self):
323331
net = tf.cast(net, dtype=tf.float32)
324332
self._logits = tf.layers.dense(inputs=net, units=1)
325333

326-
self.probability = tf.math.sigmoid(self._logits)
327-
self.output = tf.round(self.probability)
334+
self.probability = tf.math.sigmoid(self._logits,name="probability")
335+
self.output = tf.round(self.probability,name="output")
336+
337+
return self.output
338+
339+
340+
341+
328342

329343
# compute loss
330344
def _create_loss(self):
@@ -380,7 +394,7 @@ def parse_csv(value):
380394
all_columns.pop(BUY_COLUMN[0])
381395
features = all_columns
382396
return features, labels
383-
397+
384398
'''Work Queue Feature'''
385399
if args.workqueue and not args.tf:
386400
from tensorflow.python.ops.work_queue import WorkQueue
@@ -493,6 +507,43 @@ def build_feature_columns():
493507

494508
return user_column, item_column, tag_column, key_column
495509

510+
def deldir(dir):
    """Recursively delete *dir* and everything beneath it.

    Returns False when *dir* does not exist.  If *dir* is a regular file it
    is simply removed.  (Hand-rolled equivalent of shutil.rmtree, kept to
    preserve the original call sites; note the parameter shadows the
    builtin ``dir`` — kept for backward compatibility.)
    """
    if not os.path.exists(dir):
        return False
    if os.path.isfile(dir):
        os.remove(dir)
        return
    for i in os.listdir(dir):
        t = os.path.join(dir, i)
        if os.path.isdir(t):
            deldir(t)
        else:
            os.unlink(t)
    # BUGFIX: was os.removedirs(dir), which also prunes now-empty PARENT
    # directories.  During recursion that could delete the parent the outer
    # call is still processing, making its own final removal raise
    # FileNotFoundError.  os.rmdir removes only `dir` itself.
    os.rmdir(dir)
523+
524+
class MyHook(tf.train.SessionRunHook):
    """Session hook that exports a SavedModel when the session ends.

    Per step it also fetches the 'logis' and 'prob' graph collections
    alongside the training ops.
    """

    def __init__(self,cur_model,export_dir):
        # The model object; end() reads its _feature dict (input tensors)
        # and output tensor for the export signature.
        self.model = cur_model
        # Destination directory for the SavedModel; wiped and recreated
        # in end().
        self.dir = export_dir

    def before_run(self, run_context):
        """Return SessionRunArgs so these fetches run together with session.run."""
        # NOTE(review): 'logis' looks like a typo for 'logits' — confirm
        # against wherever the collection is populated.  An unknown
        # collection name silently yields an empty list here.
        v1 = tf.get_collection('logis')
        prob = tf.get_collection('prob')
        return tf.train.SessionRunArgs(fetches=[v1, prob])


    def end(self,session):
        # Start from an empty export directory — presumably because
        # simple_save expects the target path not to exist; TODO confirm.
        if os.path.exists(self.dir):
            deldir(self.dir)
        os.mkdir(self.dir)

        # Export with the model's input tensors as the signature inputs and
        # the (rounded) prediction tensor as the single output.
        tf.saved_model.simple_save(
            session,
            self.dir,
            inputs = self.model._feature,
            outputs = {"predict":self.model.output}
        )
496547

497548
def train(sess_config,
498549
input_hooks,
@@ -509,7 +560,7 @@ def train(sess_config,
509560
scaffold = tf.train.Scaffold(
510561
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
511562
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))
512-
563+
# save_hook = MyHook(model,"/root/deeprec/DeepRec/modelzoo/BST/result/savedmodels")
513564
stop_hook = tf.train.StopAtStepHook(last_step=steps)
514565
log_hook = tf.train.LoggingTensorHook(
515566
{
@@ -518,6 +569,8 @@ def train(sess_config,
518569
}, every_n_iter=100)
519570
hooks.append(stop_hook)
520571
hooks.append(log_hook)
572+
dir = "/home/deeprec/DeepRec/modelzoo/BST/result/savedmodels"
573+
# hooks.append(save_hook)
521574
if args.timeline > 0:
522575
hooks.append(
523576
tf.train.ProfilerHook(save_steps=args.timeline,
@@ -545,8 +598,12 @@ def train(sess_config,
545598
summary_dir=checkpoint_dir,
546599
save_summaries_steps=args.save_steps,
547600
config=sess_config) as sess:
601+
548602
while not sess.should_stop():
549603
sess.run([model.loss, model.train_op])
604+
605+
606+
550607
print("Training completed.")
551608

552609

@@ -575,6 +632,9 @@ def eval(sess_config, input_hooks, model, data_init_op, steps, checkpoint_dir):
575632
writer.add_summary(events, _in)
576633
print("Evaluation complate:[{}/{}]".format(_in, steps))
577634
print("ACC = {}\nAUC = {}".format(eval_acc, eval_auc))
635+
636+
637+
578638

579639

580640
def main(tf_config=None, server=None):
@@ -629,6 +689,8 @@ def main(tf_config=None, server=None):
629689

630690
# create feature column
631691
user_column, item_column, tag_column, key_column = build_feature_columns()
692+
693+
632694

633695
# create variable partitioner for distributed training
634696
num_ps_replicas = len(tf_config['ps_hosts']) if tf_config else 0
@@ -682,6 +744,9 @@ def main(tf_config=None, server=None):
682744
if not (args.no_eval or tf_config):
683745
eval(sess_config, hooks, model, test_init_op, test_steps,
684746
checkpoint_dir)
747+
748+
749+
685750

686751

687752
def boolean_string(string):

modelzoo/DBMTL/pb_to_pbtxt.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Convert the DBMTL SavedModel protobuf into a text-format (pbtxt) dump."""
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging

# SavedModel export directory to read from.
source_dir="/home/deeprec/DeepRec/modelzoo/DBMTL/savedmodels/1657784766"

logging.info("before _parse_saved_model.")
parsed = loader_impl._parse_saved_model(source_dir)  # private TF API
logging.info("_parse_saved_model done.")

# Emit <source_dir>/saved_model.pbtxt alongside the binary proto.
path = source_dir + "/saved_model.pb"
file_io.write_string_to_file("{}txt".format(path), str(parsed))

0 commit comments

Comments
 (0)