@@ -17,7 +17,6 @@ def __init__(
1717 ):
1818 self ._repo_id = repo_id
1919 self ._sample_size = sample_size
20-
2120 self ._num_procs = cpu_count () - 1
2221 self ._data = self .load (split , cols )
2322 secho (f"Total records loaded: { len (self ._data )} " , fg = "green" )
@@ -70,7 +69,7 @@ def format_document(**kwargs):
7069 def load (self , split : str , cols : list [str ] = None ):
7170 secho (
7271 f"Loading data from { self ._repo_id } using: { self ._num_procs } cores" ,
73- fg = "yellow" ,
72+ fg = ( 229 , 192 , 123 ) ,
7473 )
7574 data = load_dataset (self .repo_id , num_proc = self ._num_procs , split = split , columns = cols )
7675 if self ._sample_size is None :
@@ -79,12 +78,12 @@ def load(self, split: str, cols: list[str] = None):
7978 return data .shuffle (seed = RANDOM_STATE ).select (range (self ._sample_size ))
8079
8180 def generate_pairs (self ):
82- self . pairs = self ._data
83- metadata = [{"source" : self .name }] * len (self . pairs )
84- self . pairs = self . pairs .add_column ("metadata" , metadata )
85- secho (f"Generated { len (self . pairs )} pairs." , fg = "green" )
86- secho (f"First sample: { self . pairs [0 ]} " , fg = "yellow" )
87- return self . pairs
81+ pairs = self ._data
82+ metadata = [{"source" : self .name }] * len (pairs )
83+ pairs = pairs .add_column ("metadata" , metadata )
84+ secho (f"Generated { len (pairs )} pairs." , fg = "green" )
85+ secho (f"First sample: { pairs [0 ]} " , fg = ( 229 , 192 , 123 ) )
86+ return pairs
8887
8988 def generate_triplets (self , threshold = 3.0 ):
9089 positives = self .generate_positives (threshold = threshold ).to_pandas ()
@@ -98,10 +97,10 @@ def generate_triplets(self, threshold=3.0):
9897 triplets ["metadata" ] = triplets [metadata_cols ].apply (lambda x : json .dumps (x .to_dict ()), axis = 1 )
9998 triplets = triplets .drop (columns = metadata_cols )
10099
101- self . triplets = Dataset .from_pandas (triplets , preserve_index = False )
102- secho (f"Generated { len (self . triplets )} triplets." , fg = "green" )
103- secho (f"First sample: { self . triplets [0 ]} " , fg = "yellow" )
104- return self . triplets
100+ triplets = Dataset .from_pandas (triplets , preserve_index = False )
101+ secho (f"Generated { len (triplets )} triplets." , fg = "green" )
102+ secho (f"First sample: { triplets [0 ]} " , fg = ( 229 , 192 , 123 ) )
103+ return triplets
105104
106105 def generate_positives (self , threshold ):
107106 pos = self ._data .filter (lambda x : x ["relevance" ] >= threshold ).map (
0 commit comments