Skip to content

OpenCrate Class

This section provides a detailed reference for the OpenCrate class, which is the main entry point for interacting with the opencrate library.

Source code in opencrate/core/opencrate.py (lines 183–859 shown below):
class OpenCrate:
    """Main entry point for the opencrate library.

    Subclasses register pipeline jobs with the ``@OpenCrate.job`` decorator;
    ``__init_subclass__`` wraps the subclass ``__init__`` to set up the
    snapshot, read/write the configuration, and wrap every ``save_*`` /
    ``load_*`` method with checkpoint handling.
    """

    # Seconds allowed for deferred config evaluation (copied into
    # ``_configuration`` during subclass ``__init__``).
    config_eval_timeout: int = 60
    # Which config to use: "default", "resume", "custom", or a named config.
    use_config: str = "default"
    # Snapshot start mode or version (e.g. "reset", "new", or a number).
    start: Optional[Union[str, int]] = None
    # Optional snapshot tag.
    tag: Optional[str] = None
    log_level: str = "info"
    replace: bool = False
    # Version (and optional tag) to finetune from, if any.
    finetune: Optional[str] = None
    finetune_tag: Optional[str] = None
    # Shared snapshot facade used for logging, checkpoints, and figures.
    snapshot: snp.Snapshot = snapshot
    # snake_case name derived from the subclass name; set in __init_subclass__.
    script_name: Optional[str] = None
    # Un-patched snapshot.reset/setup captured so the patched versions can delegate.
    _original_snapshot_reset: Optional[Callable[..., Any]] = None
    _original_snapshot_setup: Optional[Callable[..., Any]] = None
    # Guard so __init_subclass__ wraps __init__ only once per subclass tree.
    _opencrate_subclass_initialized: bool = False
    # NOTE(review): the mutable containers below are class-level, so they are
    # shared by every instance and subclass — presumably an intentional
    # registry pattern; confirm before running multiple pipelines per process.
    jobs_meta_kwargs: Dict[str, Dict[str, Any]] = {}
    # Set True by save_meta(); asserted by epoch_progress/batch_progress.
    meta_saved: bool = False
    registered_checkpoint_configs_list: List[Dict[str, Any]] = []
    available_jobs: List[str] = []

    @classmethod
    def job(
        cls,
        save_on_exception: bool = False,
        execute_once: bool = False,
        upstream_jobs: Optional[List[str]] = None,
        downstream_jobs: Optional[List[str]] = None,
        concurrent: bool = False,
    ) -> Callable[..., Callable[..., None]]:
        """Decorator factory that registers a method as a named pipeline job.

        Args:
            save_on_exception: If True, call ``save_<job>()`` (when defined)
                after a KeyboardInterrupt or exception inside the job.
            execute_once: If True, skip the job body once its meta records
                ``finished`` as True.
            upstream_jobs: Registered job names run sequentially before this job.
            downstream_jobs: Registered job names run sequentially after this job.
            concurrent: Default for running the job in a daemon background
                thread; overridable per call via a ``concurrent=`` kwarg.

        The wrapper produced for the decorated method additionally accepts
        ``schedule`` ("hh:mm:ss" with "*" wildcards, "*:*:*" = continuous),
        ``schedule_timeout`` ("hh:mm:ss"), ``schedule_runout`` (max run
        count) and ``profile`` (memory/time profiling).
        """
        _concurrent = concurrent

        def decorator(job_func):
            job_name = job_func.__name__
            # Initialize this job's meta/checkpoint state in the class registry.
            cls.jobs_meta_kwargs[job_name] = meta_kwargs()

            if job_name not in cls.available_jobs:
                cls.available_jobs.append(job_name)
            else:
                cls.snapshot.error(f"Job {job_name}() is already registered. Please use a unique name for each job.")

            def wrapper(
                self,
                upstream_params: Optional[Dict[str, Any]] = None,
                downstream_params: Optional[Dict[str, Any]] = None,
                schedule: str = "",
                schedule_timeout: Optional[str] = "",
                schedule_runout: Optional[int] = None,
                profile: Optional[bool] = False,
                *args,
                **kwargs,
            ):
                # Determine master concurrency setting: a per-call
                # ``concurrent=`` kwarg overrides the decorator default.
                master_concurrent = kwargs.pop("concurrent", _concurrent)

                # Execution sequence handling dependencies
                def execute_job_sequence():
                    save_function_name = f"save_{job_name}"

                    try:
                        # Upstream jobs - force sequential execution
                        if upstream_jobs:
                            for up_job in upstream_jobs:
                                if up_job not in self.available_jobs:
                                    raise ValueError(f"Upstream job {up_job}() is not registered.")

                                params = upstream_params.get(up_job, {}) if upstream_params else {}
                                getattr(self, up_job)(concurrent=False, **params)

                        # Main job execution
                        if (not execute_once) or (not self.jobs_meta_kwargs[job_name]["finished"]):
                            if profile:
                                # NOTE(review): the profile state below keys on
                                # "train" rather than job_name — looks like a
                                # hard-coded job name; confirm intent.
                                self.jobs_meta_kwargs["train"]["profile"] = {
                                    "epoch": [],
                                    "batch": [],
                                }
                                os.makedirs(
                                    os.path.join(self.snapshot.dir_path, "profile"),
                                    exist_ok=True,
                                )
                                with MemoryProfiler(
                                    job_func,
                                    output_dir=os.path.join(
                                        self.snapshot.dir_path,
                                        "profile",
                                    ),
                                ) as profiler:
                                    # Run the profile for a specific workload
                                    # job_func(self, *args, **kwargs)
                                    profiler.run(self, *args, **kwargs)
                                    profiler.log_benchmarks(self.jobs_meta_kwargs["train"]["profile"])
                            else:
                                job_func(self, *args, **kwargs)

                            self.snapshot.info(f"Job {job_name}() has completed!")

                            if execute_once:
                                self.jobs_meta_kwargs[job_name]["finished"] = True
                            else:
                                # Reset meta so the job can run again cleanly.
                                self.jobs_meta_kwargs[job_name] = meta_kwargs()
                                if profile:
                                    self.jobs_meta_kwargs[job_name]["profile"] = {
                                        "epoch": [],
                                        "batch": [],
                                    }

                            # Save checkpoint if available
                            if hasattr(self, save_function_name):
                                getattr(self, save_function_name)()

                        # Downstream jobs - force sequential execution
                        if downstream_jobs:
                            for down_job in downstream_jobs:
                                if down_job not in self.available_jobs:
                                    raise ValueError(f"Downstream job {down_job}() is not registered.")

                                params = downstream_params.get(down_job, {}) if downstream_params else {}
                                getattr(self, down_job)(concurrent=False, **params)

                    except KeyboardInterrupt:
                        self.snapshot.info("Keyboard Interrupt occurred.")
                        if hasattr(self, save_function_name) and save_on_exception:
                            getattr(self, save_function_name)()
                            self.snapshot.info("Checkpoint saved successfully!")
                    except Exception as e:
                        # Exceptions are logged and swallowed here — callers
                        # never see them, and a scheduler loop keeps running.
                        self.snapshot.exception(e)
                        if hasattr(self, save_function_name) and save_on_exception:
                            getattr(self, save_function_name)()
                            self.snapshot.info("Checkpoint saved successfully!")

                # Scheduling logic with timeout/runout limits
                def run_scheduled_job():
                    # Validate schedule parameters
                    # NOTE(review): schedule_timeout defaults to "" (not None),
                    # so this guard only fires when None is passed explicitly.
                    if schedule_timeout is None and schedule_runout is None:
                        raise ValueError("Scheduled jobs require either `schedule_timeout` or `schedule_runout` parameter")

                    # Convert timeout to seconds
                    timeout_seconds = None
                    if schedule_timeout:
                        parts = schedule_timeout.split(":")
                        if len(parts) != 3:
                            raise ValueError("schedule_timeout must be in 'hh:mm:ss' format")
                        h, m, s = map(int, parts)
                        timeout_seconds = h * 3600 + m * 60 + s

                    try:
                        # "*:*:*" means back-to-back continuous execution.
                        h_str, m_str, s_str = schedule.split(":")
                        is_continuous = h_str == "*" and m_str == "*" and s_str == "*"

                        if is_continuous:
                            schedule_desc = "immediately after previous run completes"
                        else:
                            # Build a human-readable "every H hours, M minutes, S seconds".
                            schedule_desc = ""
                            if s_str != "*":
                                schedule_desc = f"{s_str} seconds"
                            if m_str != "*":
                                if schedule_desc != "":
                                    schedule_desc = f"{m_str} minutes, {schedule_desc}"
                                else:
                                    schedule_desc = f"{m_str} minutes"
                            if h_str != "*":
                                if schedule_desc != "":
                                    schedule_desc = f"{h_str} hours, {schedule_desc}"
                                else:
                                    schedule_desc = f"{h_str} hours"
                            schedule_desc = f"every {schedule_desc}"

                        self.snapshot.info(f"{job_name}() will run {schedule_desc}.")

                        # Add timeout/runout info
                        if timeout_seconds:
                            self.snapshot.info(f"{job_name}() will timeout after {schedule_timeout}")
                        if schedule_runout:
                            self.snapshot.info(f"{job_name}() will run at most {schedule_runout} times")

                        self.snapshot.info("Waiting for schedule trigger...")
                        last_run_time = None
                        run_count = 0
                        start_time = time.time()

                        while True:
                            now = datetime.now()
                            current_time = time.time()

                            # Check termination conditions
                            if timeout_seconds and (current_time - start_time) > timeout_seconds:
                                self.snapshot.info(f"Schedule timeout reached after {schedule_timeout}")
                                break

                            if schedule_runout and run_count >= schedule_runout:
                                self.snapshot.info(f"Schedule runout reached after {run_count} executions")
                                break

                            # For continuous mode, skip time matching logic
                            if not is_continuous:
                                # Prevent duplicate execution in same second
                                if now.strftime("%H:%M:%S") == last_run_time:
                                    time.sleep(0.5)
                                    continue

                                is_match = (h_str == "*" or now.hour == int(h_str)) and (m_str == "*" or now.minute == int(m_str)) and (s_str == "*" or now.second == int(s_str))
                            else:
                                is_match = True  # Always run in continuous mode

                            if is_match:
                                self.snapshot.info("")
                                self.snapshot.info(f"Executing scheduled {job_name}() at {now.strftime('%Y-%m-%d %H:%M:%S')}")
                                # Allow execute_once jobs to run again on each trigger.
                                self.jobs_meta_kwargs[job_name]["finished"] = False

                                execute_job_sequence()
                                run_count += 1
                                last_run_time = now.strftime("%H:%M:%S")

                                # In continuous mode, skip the sleep and go straight to next execution
                                if is_continuous:
                                    continue

                            time.sleep(0.2)  # Check time frequently

                    except KeyboardInterrupt:
                        self.snapshot.info(f"Scheduler for {job_name}() interrupted.")

                # Execution controller
                def run_job():
                    if schedule:
                        run_scheduled_job()
                    else:
                        execute_job_sequence()

                # Master concurrency decision
                if master_concurrent:
                    self.snapshot.info(f"{job_name}() running concurrently in background...")
                    thread = threading.Thread(target=run_job)
                    thread.daemon = True
                    thread.start()
                else:
                    run_job()

            return wrapper

        return decorator

    def _snapshot_reset(self, confirm) -> Optional[Any]:
        """Reset the snapshot under this pipeline's script name.

        Re-stamps the snapshot's internal name, then delegates to the
        original (un-patched) ``reset`` captured in ``__init_subclass__``.
        Returns ``None`` when no original reset was captured.
        """
        self.snapshot._name = self.script_name
        original_reset = self._original_snapshot_reset
        if original_reset is None:
            return None
        return original_reset(confirm)

    def _snapshot_setup(self, *args, **kwargs) -> Optional[Any]:
        """Delegate to the snapshot's original ``setup``, forcing ``name``.

        Any caller-supplied ``name`` keyword is discarded; the pipeline's
        ``script_name`` always wins. Returns ``None`` when no original setup
        was captured.
        """
        kwargs.pop("name", None)
        original_setup = self._original_snapshot_setup
        if original_setup is None:
            return None
        return original_setup(*args, name=self.script_name, **kwargs)

    def save_meta(self, **kwargs) -> None:
        """Capture the caller's local variables as instance attributes.

        Intended to be called from a subclass ``__init__``: every local in
        the calling frame — the ``__init__`` arguments plus any locals
        defined before this call — is set as an attribute on ``self``, and
        ``meta_saved`` is flipped so progress helpers can assert on it.

        NOTE(review): ``**kwargs`` is accepted but currently ignored
        (a commented-out implementation used to apply it) — confirm whether
        callers rely on passing extra meta here.
        """
        # f_back is the caller's frame; its locals are the "meta" to persist.
        caller_frame = inspect.currentframe().f_back  # type: ignore
        init_kwargs = inspect.getargvalues(caller_frame).locals  # type: ignore
        for key, value in init_kwargs.items():
            if key not in ("self", "__class__"):
                setattr(self, key, value)

        self.snapshot.debug(f"Initialized meta config: {init_kwargs}")
        self.meta_saved = True

    def register_checkpoint_config(self, module_name, module, get_params, update_params) -> None:
        """Attach *module* to ``self`` and, under a custom config, record how
        to re-apply its parameters after a checkpoint is loaded.

        ``get_params(module)`` captures the current (custom) parameters;
        ``update_params`` is stored and later invoked by the load decorator
        to push those parameters back onto the loaded module.
        """
        setattr(self, module_name, module)
        if self.use_config != "custom":
            return
        entry = {
            "module_name": module_name,
            "custom_config": get_params(module),
            "update_config_fn": update_params,
        }
        self.registered_checkpoint_configs_list.append(entry)
        self.snapshot.debug(f"Registered checkpoint config for '{module_name}'")

    def __call__(self):  # -> Any:
        """Pipeline entry point; subclasses must override this."""
        raise NotImplementedError

    def __init_subclass__(cls, **kwargs) -> None:
        """Finalize configuration.

        Wraps the subclass ``__init__`` so that, on instantiation, the
        snapshot is set up under a snake_case script name, the appropriate
        config is read/written/validated, the user's ``__init__`` runs under
        the ``config()`` decorator, and every ``save_*`` / ``load_*`` method
        is wrapped with checkpoint handling. Runs once per subclass tree
        (guarded by ``_opencrate_subclass_initialized``).
        """

        if getattr(cls, "_opencrate_subclass_initialized", False):
            super().__init_subclass__(**kwargs)
            return

        cls._opencrate_subclass_initialized = True

        original_init = cls.__init__

        def new_init(self, *args, **kwargs):
            # Derive snake_case script name from the subclass name
            # (e.g. "MyPipeline" -> "my_pipeline").
            # self.script_name = cls.__module__.split(".")[-1]
            self.script_name = cls.__name__
            self.script_name = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", self.script_name)
            self.script_name = re.sub(r"([a-z])([A-Z])", r"\1_\2", self.script_name).lower()

            # Setup snapshot: patch reset/setup so script_name is always
            # injected, keeping the originals for delegation.
            _configuration.snapshot = self.snapshot
            self._original_snapshot_reset = self.snapshot.reset
            self._original_snapshot_setup = self.snapshot.setup
            self.snapshot.reset = self._snapshot_reset
            self.snapshot.setup = self._snapshot_setup
            self.snapshot.setup(
                start=self.start,
                tag=self.tag,
                replace=self.replace,
                log_level=self.log_level,
            )

            # Check if checkpoint exists
            if self.finetune is not None:
                prefix = "Finetuning"
            else:
                meta_list = glob(os.path.join(self.snapshot.path.checkpoint(), "meta_*.json"))
                checkpoint_exists = len(meta_list) > 0
                prefix = "Resuming" if checkpoint_exists else "Creating"
            config_path = f"config/{self.script_name}:{self.use_config}.yml"

            # Determine if we should use existing config or create default
            use_existing_config = False
            configs = glob(os.path.join(self.snapshot.dir_path, "*.yml"))

            if self.use_config == "resume":
                if len(configs) == 0:
                    message = "\n\nCannot use `config='resume'` when creating a new snapshot, as there is no existing snapshot to resume."
                    if os.path.exists("config"):
                        available_configs = [os.path.splitext(name)[0].split(":")[-1] for name in os.listdir("config")]
                        if available_configs:
                            message += f"\nPlease use `config='default'` or one of the available configs: {', '.join(available_configs)}."
                    else:
                        message += "\nPlease use `config='default'` to create an initial configuration."
                    raise AssertionError(message)

                # NOTE(review): the backslash line-continuation inside this
                # f-string splits an intended "\n" into "\" + "n", so the
                # message renders with a literal "nThere" plus the leading
                # indentation spaces — confirm and fix the string.
                assert len(configs) == 1, (
                    f"\n\nMultiple config files found in the snapshot {self.snapshot.dir_path}.\
                        nThere must be only one config present in the snapshot to get selected for resuming the pipeline.\n"
                )
                config_name = os.path.splitext(os.path.basename(configs[0]))[0]
            else:
                config_name = f"{self.script_name}:{self.use_config}"
            # if self.use_config not in ("default", "resume") and os.path.isfile(
            if self.use_config != "resume" and os.path.isfile(config_path):
                # Use custom config file if it exists
                _configuration.read(config_name)

                _configuration.write(config_name)
                if prefix != "Finetuning":
                    _configuration.display(f"[bold]{prefix}[/bold] [bold]{self.snapshot.version_name}[/bold] with {self.use_config} config")
                else:
                    # prefix == "Finetuning" implies self.finetune is not None,
                    # so finetune_from_version is always bound before use below.
                    if self.finetune is not None:
                        if self.finetune == "dev":
                            finetune_from_version = f"{self.finetune}"
                        else:
                            finetune_from_version = f"v{self.finetune}"
                        if self.finetune_tag:
                            finetune_from_version = f"{finetune_from_version}:{self.finetune_tag}"

                    _configuration.display(f"[bold]{prefix}[/bold] from [bold]{finetune_from_version}[/bold] to [bold]{self.snapshot.version_name}[/bold] with custom config")
                use_existing_config = True
            elif self.use_config == "resume" and self.start not in ("reset", "new"):
                # Use resume config from checkpoint
                _configuration.read(config_name, load_from_use_version=True)
                # _configuration.write(
                #     f"{self.script_name}:{self.use_config}", replace_config=True
                # )

                _configuration.write(config_name)
                _configuration.display(f"[bold]{prefix} {self.snapshot.version_name}[/bold] with resume config")
                use_existing_config = True
            else:
                # Validate the requested config name against the config/ folder.
                # NOTE(review): `assert` is stripped under `python -O`; these
                # user-facing validations may deserve explicit raises.
                if os.path.exists("config") and self.use_config != "default":
                    available_config_names = [os.path.splitext(name)[0].split(":")[-1] for name in os.listdir("config")]
                    if len(available_config_names) == 1:
                        assert self.use_config in available_config_names, (
                            f"\n\nNo config found with name '{self.use_config}'.\nThe only available config in your `config/` folder is '{available_config_names[0]}'.\n"
                        )
                    else:
                        assert self.use_config in available_config_names, (
                            f"\n\nNo config found with name '{self.use_config}'.\nAvailable config names in your `config/` folder are: {', '.join(available_config_names)}.\n"
                        )
                else:
                    assert self.use_config == "default", (
                        f"\n\nNo config found with name '{self.use_config}' as no 'config' folder exists.\nYou must first create a default config by using `config='default'`.\n"
                    )

            # Initialize with appropriate config
            _configuration.config_eval_timeout = self.config_eval_timeout
            _configuration.config_eval_start = time.perf_counter()

            # Run the user's __init__ under the config() decorator so config
            # values are injected/recorded during initialization.
            decorated_init = config()(original_init)
            decorated_init(self, *args, **kwargs)

            _configuration.opencrate_init_done = True

            # If we didn't use an existing config or if starting new with resume config,
            # write the default config
            if not use_existing_config or (self.start == "new" and self.use_config == "resume"):
                _configuration.write(f"{self.script_name}:{self.use_config}", replace_config=True)
                if self.finetune:
                    if self.finetune is not None:
                        if self.finetune == "dev":
                            finetune_from_version = f"{self.finetune}"
                        else:
                            finetune_from_version = f"v{self.finetune}"
                        if self.finetune_tag:
                            finetune_from_version = f"{finetune_from_version}:{self.finetune_tag}"

                    _configuration.display(f"[bold]{prefix}[/bold] from [bold]{finetune_from_version}[/bold] to [bold]{self.snapshot.version_name}[/bold] with default config")
                else:
                    config_type = "default"
                    if not use_existing_config and self.use_config == "custom":
                        config_type += f" (as no custom config found at '{config_path}')"
                    _configuration.display(f"[bold]{prefix}[/bold] [bold]{self.snapshot.version_name}[/bold] with {config_type} config")

            # get list of all methods that start with "save_" prefix
            save_methods_names = [method_name for method_name in dir(self) if method_name.startswith("save_")]
            load_methods_names = [method_name for method_name in dir(self) if method_name.startswith("load_")]

            # Wrap every save_*/load_* method with checkpoint handling.
            for save_method_name in save_methods_names:
                setattr(
                    self,
                    save_method_name,
                    self._save_checkpoint_decorator(getattr(self, save_method_name)),
                )

            for load_method_name in load_methods_names:
                setattr(
                    self,
                    load_method_name,
                    self._load_checkpoint_decorator(getattr(self, load_method_name)),
                )

        setattr(cls, "__init__", new_init)
        super().__init_subclass__(**kwargs)
        # self.save_meta() # have this here run by default, make it optional from the user side, user will only need to call this if they need to add new meta variables

    def _save_checkpoint_decorator(self, func) -> Callable[..., None]:
        """Wrap a ``save_<job>`` method so the job's meta is checkpointed
        (as ``meta_<job>.json``) before the user's save logic runs.

        The ``batch_progress`` entry is excluded from the serialized meta.
        """

        def wrapper(*args, **kwargs):
            job_name = func.__name__.replace("save_", "")
            job_ckpt = self.jobs_meta_kwargs[job_name]
            if "batch_progress" in job_ckpt:
                # Copy without batch_progress instead of temporarily deleting
                # the key from the live meta dict — a concurrent reader (jobs
                # may run in background threads) must never observe it missing.
                job_ckpt_copy = deepcopy({key: value for key, value in job_ckpt.items() if key != "batch_progress"})
                self.snapshot.checkpoint(job_ckpt_copy, f"meta_{job_name}.json")
            else:
                self.snapshot.checkpoint(job_ckpt, f"meta_{job_name}.json")
            func(*args, **kwargs)
            self.snapshot.debug("Saved checkpoint successfully!")

        return wrapper

    def _load_checkpoint_decorator(self, func) -> Callable[..., Any]:
        """Wrap a ``load_<job>`` method with checkpoint/finetune handling.

        When finetuning, the snapshot is temporarily pointed at the finetune
        source version so ``func`` loads from it, then restored. Otherwise the
        job's saved meta (``meta_<job>.json``) is loaded (via torch) before
        ``func`` runs. Under a custom config, registered checkpoint configs
        are re-applied afterwards. Any failure is re-raised as
        ``CheckpointLoadException`` (with the original exception chained).
        """

        def wrapper(*args, **kwargs):
            try:
                if self.finetune is not None:
                    # Temporarily point the snapshot at the finetune source
                    # version; the current identity is restored after func().
                    new_version_name = self.snapshot.version_name
                    new_version = self.snapshot.version
                    new_tag = self.snapshot.tag
                    if self.finetune == "dev":
                        self.snapshot.version_name = f"{self.finetune}"
                    else:
                        self.snapshot.version_name = f"v{self.finetune}"
                    if self.finetune_tag:
                        self.snapshot.version_name = f"{self.snapshot.version_name}:{self.finetune_tag}"
                        self.snapshot.tag = self.finetune_tag
                    else:
                        self.snapshot.tag = None

                    self.snapshot.version = self.finetune
                    self.snapshot.debug(f"Loading checkpoint for finetuning from '{self.snapshot.version_name}'")
                else:
                    job_name = func.__name__.replace("load_", "")
                    meta_path = self.snapshot.path.checkpoint(f"meta_{job_name}.json", check_exists=False)
                    if not os.path.isfile(meta_path):
                        self.snapshot.debug(f"Skipping checkpoint loading, '{meta_path}' not found")
                        return  # TODO: handle this return better; currently it silently skips loading when the meta file is absent
                    try:
                        assert _has_torch, "\n\nPyTorch is not installed. Please install PyTorch to load a checkpoint.\n\n"
                        loaded_job_meta_kwargs = torch.load(meta_path, weights_only=False)
                        self.jobs_meta_kwargs[job_name] = loaded_job_meta_kwargs
                        self.snapshot.debug(f"Loaded meta variables from '{meta_path}'")
                    except Exception as e:
                        # Best-effort: a bad meta file is logged, not fatal.
                        self.snapshot.exception(f"Failed to load meta variables > {e}")

                func(*args, **kwargs)

                if self.finetune is not None:
                    # Restore the snapshot identity saved above.
                    self.snapshot.version_name = new_version_name
                    self.snapshot.version = new_version
                    self.snapshot.tag = new_tag

                if self.use_config == "custom":
                    # Re-apply custom config parameters recorded by
                    # register_checkpoint_config() onto the loaded modules.
                    for module in self.registered_checkpoint_configs_list:
                        module["update_config_fn"](
                            getattr(self, module["module_name"]),
                            module["custom_config"],
                        )
                        self.snapshot.debug(f"Updated checkpoint config for '{module['module_name']}' to '{module['custom_config']}'")

                self.snapshot.debug("Loaded checkpoint successfully!")
            except Exception as e:
                # Chain the original exception so the traceback isn't lost.
                msg = str(e).replace("\n", "")
                raise CheckpointLoadException(f"Failed to load checkpoint. {msg}") from e

        return wrapper

    def epoch_progress(self, num_epochs, title: str = "Epoch") -> Generator[Any, Any, None]:
        """Yield epoch indices while tracking progress/profiling in job meta.

        Must be called directly from a registered job method: the job name is
        recovered from the caller's frame via ``inspect.stack()``.

        Args:
            num_epochs: Total number of epochs to iterate to.
            title: Display title stored as the job's ``epoch_title``.

        Yields:
            The current epoch index, resuming from the job's ``start_epoch``.
        """
        # The calling job method's name keys into jobs_meta_kwargs.
        job_name = inspect.stack()[1].function

        assert self.meta_saved, "Meta variables not saved. Please call `save_meta()` in `__init__` method."

        self.jobs_meta_kwargs[job_name]["epoch_title"] = title
        # Stored on the instance so batch_progress() can render "Epoch(i/N)".
        self.num_epochs = num_epochs

        do_profile = self.jobs_meta_kwargs[job_name]["profile"] is not None

        # The loop target writes current_epoch straight into the job meta so
        # checkpoints taken mid-run always see the in-progress epoch.
        for self.jobs_meta_kwargs[job_name]["current_epoch"] in range(self.jobs_meta_kwargs[job_name]["start_epoch"], self.num_epochs):
            if do_profile:
                start_time = time.perf_counter()

            yield self.jobs_meta_kwargs[job_name]["current_epoch"]

            if do_profile:
                # Wall-clock seconds for the epoch body executed by the caller.
                self.jobs_meta_kwargs[job_name]["profile"]["epoch"].append(time.perf_counter() - start_time)

            for fig_title, fig in self.jobs_meta_kwargs[job_name]["batch_progress"].plot_accumulated_metrics(epoch=f"{self.jobs_meta_kwargs[job_name]['current_epoch'] + 1}"):
                # Save epoch-specific version if epoch number is available
                fig_title = fig_title.replace(", ", "_")
                fig_path = f"monitored/{job_name}({fig_title})[epochs].jpg"
                self.snapshot.figure(fig, fig_path)
                # NOTE(review): subplots_adjust runs after snapshot.figure(),
                # so it presumably does not affect the saved image — confirm.
                plt.subplots_adjust(left=0.08, right=0.92, top=0.94, bottom=0.06)
                plt.close(fig)
            plt.close("all")
            # Advance the resume point only after the epoch fully completed.
            self.jobs_meta_kwargs[job_name]["start_epoch"] += 1
        # TODO: consider automating and standardizing some of such common variable names in ML projects

    def batch_progress(self, dataloader, title="Batch") -> Generator[Tuple[Any, Any, Any], Any, None]:
        """
        Iterate ``dataloader`` yielding ``(batch_idx, batch, progress)`` while
        persisting per-job resume state and metrics.

        The job name is taken from the *caller's* function name, so all
        bookkeeping in ``self.jobs_meta_kwargs`` is keyed per calling method.
        Around each ``yield`` the job's ``is_resuming``/``start_batch_idx``
        flags are toggled so that a checkpoint written mid-batch records
        whether (and where) this loop must resume.  Requires ``save_meta()``
        to have been called in ``__init__`` first.
        """
        job_name = inspect.stack()[1].function

        assert self.meta_saved, "Meta variables not saved. Please call `save_meta()` in `__init__` method."

        # When resuming, skip the batch that was in flight at checkpoint time.
        if self.jobs_meta_kwargs[job_name]["is_resuming"]:
            self.jobs_meta_kwargs[job_name]["start_batch_idx"] += 1

        metrics_are_not_resumed = True
        if self.jobs_meta_kwargs[job_name]["epoch_title"] is not None:
            # e.g. "Epoch(3/10)" when driven by epoch_progress.
            epoch_title = f"{self.jobs_meta_kwargs[job_name]['epoch_title']}({self.jobs_meta_kwargs[job_name]['current_epoch'] + 1}/{self.num_epochs})"
        else:
            epoch_title = ""

        do_profile = self.jobs_meta_kwargs[job_name]["profile"] is not None

        # `progress` yields a fresh progress object each iteration; the loop
        # target stores it in the job meta so checkpoints can reach it.
        for batch_idx, batch, self.jobs_meta_kwargs[job_name]["batch_progress"] in progress(
            dataloader,
            title=epoch_title,
            step=title,
            step_start=self.jobs_meta_kwargs[job_name]["start_batch_idx"],
            job_name=job_name,
        ):
            # On the first iteration only, restore previously saved metrics
            # into the new progress object.
            if metrics_are_not_resumed:
                for metric_name, metric_values in self.jobs_meta_kwargs[job_name]["metrics"].items():
                    self.jobs_meta_kwargs[job_name]["batch_progress"].metrics[metric_name] = metric_values
                for metric_name, metrics_accumulated_values in self.jobs_meta_kwargs[job_name]["metrics_accumulated"].items():
                    self.jobs_meta_kwargs[job_name]["batch_progress"].metrics_accumulated[metric_name] = metrics_accumulated_values
                metrics_are_not_resumed = False

            # Before handing control to the caller, mark this batch as
            # in-flight: a checkpoint taken now must resume at `batch_idx`.
            (
                self.jobs_meta_kwargs[job_name]["is_resuming"],
                self.jobs_meta_kwargs[job_name]["start_batch_idx"],
            ) = (
                True,
                batch_idx,
            )
            if do_profile:
                start_time = time.perf_counter()

            yield batch_idx, batch, self.jobs_meta_kwargs[job_name]["batch_progress"]

            if do_profile:
                # Record how long the caller spent processing this batch.
                self.jobs_meta_kwargs[job_name]["profile"]["batch"].append(time.perf_counter() - start_time)
            # Batch finished cleanly: clear the resume flag and advance.
            (
                self.jobs_meta_kwargs[job_name]["is_resuming"],
                self.jobs_meta_kwargs[job_name]["start_batch_idx"],
            ) = (
                False,
                self.jobs_meta_kwargs[job_name]["start_batch_idx"] + 1,
            )

        # if hasattr(self, "_custom_batch_progress"):
        # After the loop, copy the progress object's metrics back into the job
        # meta so they survive checkpointing, and reset the batch pointer.
        if self.jobs_meta_kwargs[job_name]["batch_progress"] is not None:
            for (
                metric_name,
                metric_values,
            ) in self.jobs_meta_kwargs[job_name]["batch_progress"].metrics.items():
                self.jobs_meta_kwargs[job_name]["metrics"][metric_name] = metric_values

            for (
                metric_name,
                metrics_accumulated_values,
            ) in self.jobs_meta_kwargs[job_name]["batch_progress"].metrics_accumulated.items():
                self.jobs_meta_kwargs[job_name]["metrics_accumulated"][metric_name] = metrics_accumulated_values

            self.jobs_meta_kwargs[job_name]["start_batch_idx"] = 0

    @classmethod
    def launch(cls, *args, **kwargs) -> Any:
        """Launch this workflow through the CLI environment helper."""
        from ..cli.environment import launch

        # Resolve the workflow target from the module's leaf name; fall back
        # to the class itself for __main__ / top-level modules.
        # NOTE(review): `split(".")[-1]` can never contain ".", so this branch
        # is always taken and the class is always passed — confirm intent.
        target: Union[str, Type[OpenCrate]] = cls.__module__.split(".")[-1]
        if isinstance(target, str) and (target == "__main__" or "." not in target):
            target = cls
        # Drop any caller-supplied `workflow` so ours always wins.
        kwargs.pop("workflow", None)
        return launch(*args, **kwargs, workflow=target)

    def __str__(self) -> str:
        """Render the instance as a multi-line summary of its key settings."""
        fields = (
            ("version", self.snapshot.version_name),
            ("tag", self.tag),
            ("replace", self.replace),
            ("config", self.use_config),
            ("finetune", self.finetune),
            ("finetune_tag", self.finetune_tag),
        )
        body = ",\n    ".join(f"{key}={value}" for key, value in fields)
        return f"{type(self).__name__}(\n    {body},\n)"

    def __repr__(self) -> str:
        # Delegate to __str__ so both renderings stay in sync.
        return str(self)

__init_subclass__(**kwargs)

Finalize the subclass's configuration by wrapping its `__init__` with snapshot setup, config resolution, and checkpoint decoration.

Source code in opencrate/core/opencrate.py
def __init_subclass__(cls, **kwargs) -> None:
    """
    Finalize configuration for a concrete subclass.

    Replaces the subclass's ``__init__`` with a wrapper that, at
    instantiation time: derives a snake_case script name, sets up the
    snapshot, resolves and writes the YAML config (default / custom /
    resume), runs the original ``__init__`` under the config decorator, and
    wraps every ``save_*`` / ``load_*`` method with checkpoint decorators.
    A class-level guard makes the wrapping run only once per subclass.
    """

    # Idempotence guard: a subclass already wrapped just defers to the parent.
    if getattr(cls, "_opencrate_subclass_initialized", False):
        super().__init_subclass__(**kwargs)
        return

    cls._opencrate_subclass_initialized = True

    original_init = cls.__init__

    def new_init(self, *args, **kwargs):
        # Derive a snake_case script name from the CamelCase class name.
        # self.script_name = cls.__module__.split(".")[-1]
        self.script_name = cls.__name__
        self.script_name = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", self.script_name)
        self.script_name = re.sub(r"([a-z])([A-Z])", r"\1_\2", self.script_name).lower()

        # Setup snapshot: install the instance-bound reset/setup wrappers,
        # keeping references to the originals for later use.
        _configuration.snapshot = self.snapshot
        self._original_snapshot_reset = self.snapshot.reset
        self._original_snapshot_setup = self.snapshot.setup
        self.snapshot.reset = self._snapshot_reset
        self.snapshot.setup = self._snapshot_setup
        self.snapshot.setup(
            start=self.start,
            tag=self.tag,
            replace=self.replace,
            log_level=self.log_level,
        )

        # Check if checkpoint exists — this picks the banner verb for display.
        if self.finetune is not None:
            prefix = "Finetuning"
        else:
            meta_list = glob(os.path.join(self.snapshot.path.checkpoint(), "meta_*.json"))
            checkpoint_exists = len(meta_list) > 0
            prefix = "Resuming" if checkpoint_exists else "Creating"
        config_path = f"config/{self.script_name}:{self.use_config}.yml"

        # Determine if we should use existing config or create default
        use_existing_config = False
        configs = glob(os.path.join(self.snapshot.dir_path, "*.yml"))

        if self.use_config == "resume":
            # Resuming requires exactly one config file inside the snapshot.
            if len(configs) == 0:
                message = "\n\nCannot use `config='resume'` when creating a new snapshot, as there is no existing snapshot to resume."
                if os.path.exists("config"):
                    available_configs = [os.path.splitext(name)[0].split(":")[-1] for name in os.listdir("config")]
                    if available_configs:
                        message += f"\nPlease use `config='default'` or one of the available configs: {', '.join(available_configs)}."
                else:
                    message += "\nPlease use `config='default'` to create an initial configuration."
                raise AssertionError(message)

            # NOTE(review): the line-continuation backslash below splits what was
            # presumably an intended "\n", so the message renders with a stray
            # leading "n" and embedded indentation — confirm and fix the literal.
            assert len(configs) == 1, (
                f"\n\nMultiple config files found in the snapshot {self.snapshot.dir_path}.\
                    nThere must be only one config present in the snapshot to get selected for resuming the pipeline.\n"
            )
            config_name = os.path.splitext(os.path.basename(configs[0]))[0]
        else:
            config_name = f"{self.script_name}:{self.use_config}"
        # if self.use_config not in ("default", "resume") and os.path.isfile(
        if self.use_config != "resume" and os.path.isfile(config_path):
            # Use custom config file if it exists
            _configuration.read(config_name)

            _configuration.write(config_name)
            if prefix != "Finetuning":
                _configuration.display(f"[bold]{prefix}[/bold] [bold]{self.snapshot.version_name}[/bold] with {self.use_config} config")
            else:
                # Build the "vN[:tag]" label of the version being finetuned from.
                if self.finetune is not None:
                    if self.finetune == "dev":
                        finetune_from_version = f"{self.finetune}"
                    else:
                        finetune_from_version = f"v{self.finetune}"
                    if self.finetune_tag:
                        finetune_from_version = f"{finetune_from_version}:{self.finetune_tag}"

                _configuration.display(f"[bold]{prefix}[/bold] from [bold]{finetune_from_version}[/bold] to [bold]{self.snapshot.version_name}[/bold] with custom config")
            use_existing_config = True
        elif self.use_config == "resume" and self.start not in ("reset", "new"):
            # Use resume config from checkpoint
            _configuration.read(config_name, load_from_use_version=True)
            # _configuration.write(
            #     f"{self.script_name}:{self.use_config}", replace_config=True
            # )

            _configuration.write(config_name)
            _configuration.display(f"[bold]{prefix} {self.snapshot.version_name}[/bold] with resume config")
            use_existing_config = True
        else:
            # No usable config on disk: validate the requested config name
            # against what is available before falling back to defaults.
            if os.path.exists("config") and self.use_config != "default":
                available_config_names = [os.path.splitext(name)[0].split(":")[-1] for name in os.listdir("config")]
                if len(available_config_names) == 1:
                    assert self.use_config in available_config_names, (
                        f"\n\nNo config found with name '{self.use_config}'.\nThe only available config in your `config/` folder is '{available_config_names[0]}'.\n"
                    )
                else:
                    assert self.use_config in available_config_names, (
                        f"\n\nNo config found with name '{self.use_config}'.\nAvailable config names in your `config/` folder are: {', '.join(available_config_names)}.\n"
                    )
            else:
                assert self.use_config == "default", (
                    f"\n\nNo config found with name '{self.use_config}' as no 'config' folder exists.\nYou must first create a default config by using `config='default'`.\n"
                )

        # Initialize with appropriate config
        _configuration.config_eval_timeout = self.config_eval_timeout
        _configuration.config_eval_start = time.perf_counter()

        # Run the user's original __init__ under the config decorator.
        decorated_init = config()(original_init)
        decorated_init(self, *args, **kwargs)

        _configuration.opencrate_init_done = True

        # If we didn't use an existing config or if starting new with resume config,
        # write the default config
        if not use_existing_config or (self.start == "new" and self.use_config == "resume"):
            _configuration.write(f"{self.script_name}:{self.use_config}", replace_config=True)
            if self.finetune:
                # Build the "vN[:tag]" label of the version being finetuned from.
                if self.finetune is not None:
                    if self.finetune == "dev":
                        finetune_from_version = f"{self.finetune}"
                    else:
                        finetune_from_version = f"v{self.finetune}"
                    if self.finetune_tag:
                        finetune_from_version = f"{finetune_from_version}:{self.finetune_tag}"

                _configuration.display(f"[bold]{prefix}[/bold] from [bold]{finetune_from_version}[/bold] to [bold]{self.snapshot.version_name}[/bold] with default config")
            else:
                config_type = "default"
                if not use_existing_config and self.use_config == "custom":
                    config_type += f" (as no custom config found at '{config_path}')"
                _configuration.display(f"[bold]{prefix}[/bold] [bold]{self.snapshot.version_name}[/bold] with {config_type} config")

        # get list of all methods that start with "save_" prefix
        save_methods_names = [method_name for method_name in dir(self) if method_name.startswith("save_")]
        load_methods_names = [method_name for method_name in dir(self) if method_name.startswith("load_")]

        # Rebind every save_* / load_* method on the instance, wrapped with
        # checkpoint save/load handling.
        for save_method_name in save_methods_names:
            setattr(
                self,
                save_method_name,
                self._save_checkpoint_decorator(getattr(self, save_method_name)),
            )

        for load_method_name in load_methods_names:
            setattr(
                self,
                load_method_name,
                self._load_checkpoint_decorator(getattr(self, load_method_name)),
            )

    setattr(cls, "__init__", new_init)
    super().__init_subclass__(**kwargs)

save_meta(**kwargs)

Sets all of the `__init__` arguments as attributes on the class instance.

Source code in opencrate/core/opencrate.py
def save_meta(self, **kwargs) -> None:
    """
    Set every argument of the calling ``__init__`` as an attribute on the
    instance.

    Walks one frame up the call stack (the caller is expected to be
    ``__init__``) and copies each of the caller's local variables onto
    ``self``, skipping the ``self`` reference and the implicit
    ``__class__`` cell.  Finally marks the instance via ``meta_saved`` so
    the progress helpers can assert meta was captured.

    Args:
        **kwargs: Accepted for backward compatibility; currently unused —
            values are read from the caller's frame instead.
    """

    # Capture the caller's locals: these are exactly the __init__ arguments
    # (plus any locals defined before this call).
    frame = inspect.currentframe().f_back  # type: ignore
    init_kwargs = inspect.getargvalues(frame).locals  # type: ignore
    for key, value in init_kwargs.items():
        # Skip the instance reference and the implicit closure cell.
        if key not in ("self", "__class__"):
            setattr(self, key, value)

    self.snapshot.debug(f"Initialized meta config: {init_kwargs}")
    self.meta_saved = True