DagsHubCallback fails on Windows #621

@mattiacurri

dagshub.init() sets MLFLOW_TRACKING_URI to a URL. The DagsHub integration then splits that value on os.sep, which is a backslash on Windows, so the URL is never split and indexing into the result fails. On Linux there is no problem. I can't test on macOS, but the separator there should be the same as on Linux.

import dagshub
from datasets import Dataset
import mlflow
from setfit import SetFitModel, Trainer, TrainingArguments

if __name__ == "__main__":
    dagshub.init(repo_owner="XXX", repo_name="YYY", mlflow=True)
    
    mlflow.set_experiment("issue_dagshub")
    
    train_data = Dataset.from_dict({
        "text": ["example 1", "example 2", "example 3"],
        "label": [[1, 0], [0, 1], [1, 1]]
    })
    
    model = SetFitModel.from_pretrained(
        "sentence-transformers/paraphrase-MiniLM-L3-v2",
        multi_target_strategy="multi-output"
    )
    
    with mlflow.start_run(run_name="minimal-test-setfit"):
        args = TrainingArguments(
            num_epochs=1,
            batch_size=2,
            num_iterations=1,
            report_to="mlflow"
        )
        
        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=train_data,
            column_mapping={"text": "text", "label": "label"}
        )
        
        trainer.train()
Accessing as XXX
Initialized MLflow to track repo "XXX/YYY"
Repository XXX/YYY initialized!
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\mlflow\store\tracking\rest_store.py:211: DeprecationWarning: label() is deprecated. Use is_required() or is_repeated() instead.
  req_body = message_to_json(
Applying column mapping to the training dataset
Map: 100%|████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 1367.56 examples/s]
***** Running training *****
  Num unique pairs = 6
  Batch size = 2
  Num epochs = 1
🏃 View run minimal-test-setfit at: X
🧪 View experiment at: X
C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\mlflow\store\tracking\rest_store.py:182: DeprecationWarning: label() is deprecated. Use is_required() or is_repeated() instead.
  req_body = message_to_json(
Traceback (most recent call last):
  File "C:\Users\X\Desktop\Progetti\my-first-repo\minimal_issue.py", line 37, in <module>
    trainer.train()  # Issue occurs here
    ^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 531, in train  
    self.train_embeddings(*full_parameters, args=args)
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 582, in train_embeddings
    self.st_trainer.train()
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\trainer.py", line 2325, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\trainer.py", line 2573, in _inner_training_loop
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\trainer_callback.py", line 506, in on_train_begin
    return self.call_event("on_train_begin", args, state, control)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 97, in <lambda>
    self.callback_handler.call_event = lambda *args, **kwargs: overwritten_call_event(
                                                               ^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\setfit\trainer.py", line 74, in overwritten_call_event
    result = getattr(callback, event)(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\integrations\integration_utils.py", line 1489, in on_train_begin
    self.setup(args, state, model)
  File "C:\Users\X\Desktop\Progetti\my-first-repo\.venv\Lib\site-packages\transformers\integrations\integration_utils.py", line 1569, in setup
    owner=self.remote.split(os.sep)[-2],
          ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
IndexError: list index out of range

I'm using setfit==1.1.3.
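
To illustrate the failure mode, here is a minimal sketch that reproduces the `split(os.sep)[-2]` lookup from the traceback on any OS by passing the POSIX and Windows separators explicitly. The tracking URI shape (`https://dagshub.com/<owner>/<repo>.mlflow`) is an assumption, and `owner_via_sep` and `owner_and_repo` are hypothetical helper names, not part of transformers or dagshub. The `urlparse`-based variant is only one possible separator-independent approach, not necessarily how the library should be fixed.

```python
import ntpath      # Windows path rules; importable on every platform
import posixpath   # POSIX path rules; importable on every platform
from urllib.parse import urlparse

# Assumed shape of the tracking URI that dagshub.init() sets.
remote = "https://dagshub.com/XXX/YYY.mlflow"


def owner_via_sep(uri, sep):
    """Mimic the failing lookup: split on a path separator, take index -2."""
    return uri.split(sep)[-2]


def owner_and_repo(uri):
    """Hypothetical alternative: URLs always use '/', so split the URL path on it."""
    parts = urlparse(uri).path.strip("/").split("/")
    owner, repo = parts[-2], parts[-1]
    # Drop the ".mlflow" suffix from the repository name.
    return owner, repo.split(".")[0]


# With the POSIX separator the lookup finds the owner segment.
print(owner_via_sep(remote, posixpath.sep))

# With the Windows separator the URL is not split at all,
# so the list has one element and index -2 raises IndexError.
try:
    owner_via_sep(remote, ntpath.sep)
except IndexError as exc:
    print("Windows separator:", exc)

# The URL-based variant does not depend on the OS separator.
print(owner_and_repo(remote))
```

Because a URL always uses forward slashes regardless of platform, any lookup that splits on `"/"` or parses the URL would behave the same on Windows and Linux.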
