-
Notifications
You must be signed in to change notification settings - Fork 439
/
Copy pathlanguage_model_arguments.py
71 lines (69 loc) · 2.45 KB
/
language_model_arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from dataclasses import dataclass, field
@dataclass
class LanguageModelHandlerArguments:
    """Configuration arguments for the language-model handler.

    Each field's ``metadata["help"]`` string is surfaced as CLI help text
    by the argument parser that consumes this dataclass.
    """

    # Hugging Face model id (or local path) of the chat model to load.
    lm_model_name: str = field(
        default="HuggingFaceTB/SmolLM-360M-Instruct",
        metadata={
            "help": "The pretrained language model to use. Default is 'HuggingFaceTB/SmolLM-360M-Instruct'."
        },
    )
    # Device string passed to the model loader (e.g. 'cuda', 'cpu', 'mps').
    lm_device: str = field(
        default="cuda",
        metadata={
            "help": "The device type on which the model will run. Default is 'cuda' for GPU acceleration."
        },
    )
    # Torch dtype name; resolved to an actual torch.dtype by the handler.
    lm_torch_dtype: str = field(
        default="float16",
        metadata={
            "help": "The PyTorch data type for the model and input tensors. One of `float32` (full-precision), `float16` or `bfloat16` (both half-precision)."
        },
    )
    # Chat-template role used for incoming user messages.
    user_role: str = field(
        default="user",
        metadata={
            "help": "Role assigned to the user in the chat context. Default is 'user'."
        },
    )
    # Chat-template role used for the initial context-setting message.
    init_chat_role: str = field(
        default="system",
        metadata={
            "help": "Initial role for setting up the chat context. Default is 'system'."
        },
    )
    # FIX: the previous help text quoted a wrong default ("You are a helpful
    # AI assistant.") that did not match the actual default below.
    init_chat_prompt: str = field(
        default="You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise responses of less than 20 words.",
        metadata={
            "help": "The initial chat prompt to establish context for the language model."
        },
    )
    # Generation length bounds forwarded to `model.generate`.
    lm_gen_max_new_tokens: int = field(
        default=128,
        metadata={
            "help": "Maximum number of new tokens to generate in a single completion. Default is 128."
        },
    )
    lm_gen_min_new_tokens: int = field(
        default=0,
        metadata={
            "help": "Minimum number of new tokens to generate in a single completion. Default is 0."
        },
    )
    # Sampling temperature; 0.0 together with do_sample=False gives
    # deterministic (greedy) decoding.
    lm_gen_temperature: float = field(
        default=0.0,
        metadata={
            "help": "Controls the randomness of the output. Set to 0.0 for deterministic (repeatable) outputs. Default is 0.0."
        },
    )
    lm_gen_do_sample: bool = field(
        default=False,
        metadata={
            "help": "Whether to use sampling; set this to False for deterministic outputs. Default is False."
        },
    )
    # NOTE(review): the help text mentions passing None for no limit, but the
    # annotation is plain `int` — confirm whether the consumer actually
    # accepts None before relying on that; fixed typo "assitant".
    chat_size: int = field(
        default=2,
        metadata={
            "help": "Number of interactions assistant-user to keep for the chat. None for no limitations."
        },
    )