This commit is contained in:
Harshil 2023-07-17 21:39:34 -05:00 committed by GitHub
commit 0ec0c2bc29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -387,15 +387,16 @@ class Trainer(object):
fold=i, fold=i,
) )
else: else:
raise ValueError("Sure you want to do multiple fold training") a = input("Are you sure you want to do multiple fold training? (y/N)")
for mb_generator, steps_per_epoch, val_data, test_data in self.mb_loader(full_df=df): if a.lower() == "y":
self._train_single_fold( for mb_generator, steps_per_epoch, val_data, test_data in self.mb_loader(full_df=df):
mb_generator=mb_generator, self._train_single_fold(
val_data=val_data, mb_generator=mb_generator,
test_data=test_data, val_data=val_data,
steps_per_epoch=steps_per_epoch, test_data=test_data,
fold=i, steps_per_epoch=steps_per_epoch,
) fold=i,
i += 1 )
if i == 3: i += 1
break if i == 3:
break