11#!/usr/bin/env python3
22# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
33
4- from typing import List
4+ from typing import List , Optional , Tuple
55
66import torch
77from reagent .core import types as rlt
88from reagent .models import FullyConnectedNetwork
99from reagent .models .base import ModelBase
10- from torchrec .models .dlrm import SparseArch
10+ from torchrec .models .dlrm import SparseArchRO
1111from torchrec .modules .embedding_modules import EmbeddingBagCollection
1212from torchrec .sparse .jagged_tensor import KeyedJaggedTensor
1313
1616@torch .fx .wrap
1717def fetch_id_list_features (
1818 state : rlt .FeatureData , action : rlt .FeatureData
19- ) -> KeyedJaggedTensor :
20- assert state .id_list_features is not None or action .id_list_features is not None
21- if state .id_list_features is not None and action .id_list_features is None :
22- sparse_features = state .id_list_features
23- elif state .id_list_features is None and action .id_list_features is not None :
24- sparse_features = action .id_list_features
25- elif state .id_list_features is not None and action .id_list_features is not None :
26- sparse_features = KeyedJaggedTensor .concat (
27- [state .id_list_features , action .id_list_features ]
28- )
29- else :
19+ ) -> Tuple [Optional [KeyedJaggedTensor ], Optional [KeyedJaggedTensor ]]:
20+ assert (
21+ state .id_list_features is not None
22+ or state .id_list_features_ro is not None
23+ or action .id_list_features is not None
24+ or action .id_list_features_ro is not None
25+ )
26+
27+ def _get_sparse_features (
28+ id_list_features_1 , id_list_features_2
29+ ) -> Optional [KeyedJaggedTensor ]:
30+ sparse_features = None
31+ if id_list_features_1 is not None and id_list_features_2 is None :
32+ sparse_features = id_list_features_1
33+ elif id_list_features_1 is None and id_list_features_2 is not None :
34+ sparse_features = id_list_features_2
35+ elif id_list_features_1 is not None and id_list_features_2 is not None :
36+ sparse_features = KeyedJaggedTensor .concat (
37+ [id_list_features_1 , id_list_features_2 ]
38+ )
39+ return sparse_features
40+
41+ sparse_features = _get_sparse_features (
42+ state .id_list_features , action .id_list_features
43+ )
44+ sparse_features_ro = _get_sparse_features (
45+ state .id_list_features_ro , action .id_list_features_ro
46+ )
47+ if sparse_features is None and sparse_features_ro is None :
3048 raise ValueError
49+
3150 # TODO: add id_list_score_features
32- return sparse_features
51+ return sparse_features , sparse_features_ro
3352
3453
3554class SparseDQN (ModelBase ):
@@ -41,7 +60,8 @@ class SparseDQN(ModelBase):
4160 def __init__ (
4261 self ,
4362 state_dense_dim : int ,
44- embedding_bag_collection : EmbeddingBagCollection ,
63+ embedding_bag_collection : Optional [EmbeddingBagCollection ],
64+ embedding_bag_collection_ro : Optional [EmbeddingBagCollection ],
4565 action_dense_dim : int ,
4666 overarch_dims : List [int ],
4767 activation : str = "relu" ,
@@ -51,17 +71,37 @@ def __init__(
5171 output_dim : int = 1 ,
5272 ) -> None :
5373 super ().__init__ ()
54- self .sparse_arch : SparseArch = SparseArch (embedding_bag_collection )
74+ self .sparse_arch : SparseArchRO = SparseArchRO (
75+ embedding_bag_collection , embedding_bag_collection_ro
76+ )
77+
78+ self .sparse_embedding_dim : int = (
79+ sum (
80+ [
81+ len (embc .feature_names ) * embc .embedding_dim
82+ for embc in embedding_bag_collection .embedding_bag_configs ()
83+ ]
84+ )
85+ if embedding_bag_collection is not None
86+ else 0
87+ )
5588
56- self .sparse_embedding_dim : int = sum (
57- [
58- len (embc .feature_names ) * embc .embedding_dim
59- for embc in embedding_bag_collection .embedding_bag_configs ()
60- ]
89+ self .sparse_embedding_dim_ro : int = (
90+ sum (
91+ [
92+ len (embc .feature_names ) * embc .embedding_dim
93+ for embc in embedding_bag_collection .embedding_bag_configs ()
94+ ]
95+ )
96+ if embedding_bag_collection is not None
97+ else 0
6198 )
6299
63100 self .input_dim : int = (
64- state_dense_dim + self .sparse_embedding_dim + action_dense_dim
101+ state_dense_dim
102+ + self .sparse_embedding_dim
103+ + self .sparse_embedding_dim_ro
104+ + action_dense_dim
65105 )
66106 layers = [self .input_dim ] + overarch_dims + [output_dim ]
67107 activations = [activation ] * len (overarch_dims ) + [final_activation ]
@@ -76,11 +116,20 @@ def forward(self, state: rlt.FeatureData, action: rlt.FeatureData) -> torch.Tens
76116 (state .float_features , action .float_features ), dim = - 1
77117 )
78118 batch_size = dense_features .shape [0 ]
79- sparse_features = fetch_id_list_features (state , action )
119+ sparse_features , sparse_features_ro = fetch_id_list_features (state , action )
80120 # shape: batch_size, num_sparse_features, embedding_dim
81- embedded_sparse = self .sparse_arch (sparse_features )
82- # shape: batch_size, num_sparse_features * embedding_dim
83- embedded_sparse = embedded_sparse .reshape (batch_size , - 1 )
84- concatenated_dense = torch .cat ((dense_features , embedded_sparse ), dim = - 1 )
121+ embedded_sparse , embedded_sparse_ro = self .sparse_arch (
122+ sparse_features , sparse_features_ro
123+ )
124+ features_list : List [torch .Tensor ] = [dense_features ]
125+ if embedded_sparse is not None :
126+ # shape: batch_size, num_sparse_features * embedding_dim
127+ embedded_sparse = embedded_sparse .reshape (batch_size , - 1 )
128+ features_list .append (embedded_sparse )
129+ if embedded_sparse_ro is not None :
130+ # shape: batch_size, num_sparse_features * embedding_dim
131+ embedded_sparse_ro = embedded_sparse_ro .reshape (batch_size , - 1 )
132+ features_list .append (embedded_sparse_ro )
85133
134+ concatenated_dense = torch .cat (features_list , dim = - 1 )
86135 return self .q_network (concatenated_dense )
0 commit comments