#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass, field
from typing import Optional
from transformers import TFTrainingArguments, TrainingArguments
[docs]@dataclass
class T4RecTrainingArguments(TrainingArguments):
"""
Class that inherits HF TrainingArguments and add on top of it arguments needed for
session-based and sequential-based recommendation
Parameters
----------
shuffle_buffer_size:
validate_every: Optional[int], int
Run validation set every this epoch.
-1 means no validation is used
by default -1
eval_on_test_set:
eval_steps_on_train_set:
predict_top_k: Option[int], int
Truncate recommendation list to the highest top-K predicted items,
(do not affect evaluation metrics computation),
this parameter is specific to NextItemPredictionTask.
by default 0
log_predictions : Optional[bool], bool
log predictions, labels and metadata features each --compute_metrics_each_n_steps
(for test set).
by default False
log_attention_weights : Optional[bool], bool
Logs the inputs and attention weights
each --eval_steps (only test set)"
by default False
learning_rate_num_cosine_cycles_by_epoch : Optional[int], int
Number of cycles for by epoch when --lr_scheduler_type = cosine_with_warmup.
The number of waves in the cosine schedule
(e.g. 0.5 is to just decrease from the max value to 0, following a half-cosine).
by default 1.25
experiments_group: Optional[str], str
Name of the Experiments Group, for organizing job runs logged on W&B
by default "default"
"""
max_sequence_length: Optional[int] = field(
default=None,
metadata={"help": "maximum length of sequence"},
)
shuffle_buffer_size: int = field(
default=0,
metadata={
"help": "Number of samples to keep in the buffer for shuffling."
"shuffle_buffer_size=0 means no shuffling"
},
)
data_loader_engine: str = field(
default="merlin",
metadata={
"help": "Parquet data loader engine. "
"'merlin': GPU-accelerated parquet data loader from Merlin, 'pyarrow': read whole parquet into memory."
},
)
eval_on_test_set: bool = field(
default=False,
metadata={"help": "Evaluate on test set (by default, evaluates on the validation set)."},
)
eval_steps_on_train_set: int = field(
default=20,
metadata={"help": "Number of steps to evaluate on train set (which is usually large)"},
)
predict_top_k: int = field(
default=0,
metadata={
"help": "Truncate recommendation list to the highest top-K predicted items (do not affect evaluation metrics computation), "
"this parameter is specific to NextItemPredictionTask."
},
)
learning_rate_num_cosine_cycles_by_epoch: float = field(
default=1.25,
metadata={
"help": "Number of cycles for by epoch when --learning_rate_schedule = cosine_with_warmup."
"The number of waves in the cosine schedule "
"(e.g. 0.5 is to just decrease from the max value to 0, following a half-cosine)."
},
)
log_predictions: bool = field(
default=False,
metadata={
"help": "Logs predictions, labels and metadata features each --compute_metrics_each_n_steps (for test set)."
},
)
compute_metrics_each_n_steps: int = field(
default=1,
metadata={"help": "Log metrics each n steps (for train, validation and test sets)"},
)
experiments_group: str = field(
default="default",
metadata={"help": "Name of the Experiments Group, for organizing job runs logged on W&B"},
)
@property
def place_model_on_device(self):
"""
Override the method to allow running training on cpu
"""
if self.device.type == "cuda":
return True
return False
[docs]class T4RecTrainingArgumentsTF(T4RecTrainingArguments, TFTrainingArguments):
"""
Prepare Training arguments for TFTrainer,
Inherit arguments from T4RecTrainingArguments and TFTrainingArguments
"""
pass