Skip to content

Commit 40e3968

Browse files
Fine Tune Training (#1376)
* get models for web version * fine-tune front-end and web command arguments * updating user trained models * update run trianing for tuneModel arguments * electorn support * pydantic type updates * fix arguments for training * support running locally on single gpu pc's by default * update fine tune model web display * Update client/platform/web-girder/views/RunTrainingMenu.vue Co-authored-by: Michael Nagler <[email protected]> * Update server/dive_server/crud_rpc.py Co-authored-by: Michael Nagler <[email protected]> * adressing comments --------- Co-authored-by: Michael Nagler <[email protected]>
1 parent 1d603d5 commit 40e3968

File tree

16 files changed

+310
-21
lines changed

16 files changed

+310
-21
lines changed

client/dive-common/apispec.ts

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,16 @@ interface Category {
4545
}
4646

4747
interface TrainingConfigs {
48-
configs: string[];
49-
default: string;
48+
training: {
49+
configs: string[];
50+
default: string;
51+
};
52+
models: Record<string, {
53+
name: string;
54+
type: string;
55+
path?: string;
56+
folderId?: string;
57+
}>;
5058
}
5159

5260
type Pipelines = Record<string, Category>;
@@ -155,6 +163,12 @@ interface Api {
155163
config: string,
156164
annotatedFramesOnly: boolean,
157165
labelText?: string,
166+
fineTuneModel?: {
167+
name: string;
168+
type: string;
169+
path?: string;
170+
folderId?: string;
171+
},
158172
): Promise<unknown>;
159173

160174
loadMetadata(datasetId: string): Promise<DatasetMeta>;

client/platform/desktop/backend/native/common.ts

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -423,14 +423,36 @@ async function getPipelineList(settings: Settings): Promise<Pipelines> {
423423
return ret;
424424
}
425425

426+
// Function to recursively traverse a directory and collect files with specified extensions
427+
function getFilesWithExtensions(dir: string, extensions: string[], fileList: string[] = []) {
428+
const files = fs.readdirSync(dir);
429+
430+
files.forEach((file) => {
431+
const filePath = npath.join(dir, file);
432+
const fileStat = fs.statSync(filePath);
433+
434+
if (fileStat.isDirectory()) {
435+
fileList.concat(getFilesWithExtensions(filePath, extensions, fileList));
436+
} else {
437+
const fileExtension = npath.extname(file).toLowerCase();
438+
if (extensions.includes(fileExtension)) {
439+
fileList.push(filePath);
440+
}
441+
}
442+
});
443+
444+
return fileList;
445+
}
446+
426447
/**
427448
* get training configurations
428449
*/
429450
async function getTrainingConfigs(settings: Settings): Promise<TrainingConfigs> {
430451
const pipelinePath = npath.join(settings.viamePath, 'configs/pipelines');
431-
const defaultTrainingConfiguration = 'train_detector_default.conf';
432-
const allowedPatterns = /train_.*\.conf$/;
433-
const disallowedPatterns = /.*(_nf|\.continue)\.viame_csv\.conf$|.*\.kw18\.conf$|.*\.habcam\.conf$|.*\.continue\.conf$/;
452+
const defaultTrainingConfiguration = 'train_detector_default.viame_csv.conf';
453+
const allowedPatterns = /\.viame_csv\.conf$/;
454+
const disallowedPatterns = /.*(_nf|\.continue)\.viame_csv\.conf$/;
455+
const allowedModelExtensions = ['.zip', '.pth', '.pt', '.py', '.weights', '.wt'];
434456
const exists = await fs.pathExists(pipelinePath);
435457
if (!exists) {
436458
throw new Error(`Path does not exist: ${pipelinePath}`);
@@ -439,9 +461,22 @@ async function getTrainingConfigs(settings: Settings): Promise<TrainingConfigs>
439461
configs = configs
440462
.filter((p) => (p.match(allowedPatterns) && !p.match(disallowedPatterns)))
441463
.sort((a, b) => (a === defaultTrainingConfiguration ? -1 : a.localeCompare(b)));
464+
// Get Model files in the pipeline directory
465+
const modelList = getFilesWithExtensions(pipelinePath, allowedModelExtensions);
466+
const models: TrainingConfigs['models'] = {};
467+
modelList.forEach((model) => {
468+
models[npath.basename(model)] = {
469+
name: npath.basename(model),
470+
type: npath.extname(model),
471+
path: model,
472+
};
473+
});
442474
return {
443-
default: configs[0],
444-
configs,
475+
training: {
476+
default: configs[0],
477+
configs,
478+
},
479+
models,
445480
};
446481
}
447482

client/platform/desktop/backend/native/viame.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,11 @@ async function train(
394394
command.push('--gt-frames-only');
395395
}
396396

397+
if (runTrainingArgs.fineTuneModel && runTrainingArgs.fineTuneModel.path) {
398+
command.push('--init-weights');
399+
command.push(runTrainingArgs.fineTuneModel.path);
400+
}
401+
397402
const job = observeChild(spawn(command.join(' '), {
398403
shell: viameConstants.shell,
399404
cwd: jobWorkDir,

client/platform/desktop/constants.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ export interface RunTraining {
171171
annotatedFramesOnly: boolean;
172172
// contents of labels.txt file
173173
labelText?: string;
174+
// fine tuning model
175+
fineTuneModel?: {
176+
name: string;
177+
type: string;
178+
path?: string;
179+
folderId?: string;
180+
};
174181
}
175182

176183
export interface ConversionArgs {

client/platform/desktop/frontend/api.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,20 @@ async function runTraining(
110110
config: string,
111111
annotatedFramesOnly: boolean,
112112
labelText?: string,
113+
fineTuneModel?: {
114+
name: string;
115+
type: string;
116+
path?: string;
117+
folderId?: string;
118+
},
113119
): Promise<DesktopJob> {
114120
const args: RunTraining = {
115121
datasetIds: folderIds,
116122
pipelineName,
117123
trainingConfig: config,
118124
annotatedFramesOnly,
119125
labelText,
126+
fineTuneModel,
120127
};
121128
return ipcRenderer.invoke('run-training', args);
122129
}

client/platform/desktop/frontend/components/MultiTrainingMenu.vue

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,14 @@ export default defineComponent({
5252
stagedItems: {} as Record<string, DatasetMeta>,
5353
trainingOutputName: '',
5454
selectedTrainingConfig: 'foo.whatever',
55+
fineTuneTraining: false,
56+
selectedFineTune: null as null | string,
5557
trainingConfigurations: {
56-
configs: [],
57-
default: '',
58+
training: {
59+
configs: [],
60+
default: '',
61+
},
62+
models: {},
5863
} as TrainingConfigs,
5964
annotatedFramesOnly: false,
6065
});
@@ -90,7 +95,16 @@ export default defineComponent({
9095
onBeforeMount(async () => {
9196
const configs = await getTrainingConfigurations();
9297
data.trainingConfigurations = configs;
93-
data.selectedTrainingConfig = configs.default;
98+
data.selectedTrainingConfig = configs.training.default;
99+
});
100+
101+
const modelNames = computed(() => {
102+
if (data.trainingConfigurations.models) {
103+
const list = Object.entries(data.trainingConfigurations.models)
104+
.map(([, value]) => value.name);
105+
return list;
106+
}
107+
return [];
94108
});
95109
96110
function toggleStaged(meta: DatasetMeta) {
@@ -172,12 +186,20 @@ export default defineComponent({
172186
}
173187
174188
async function runTrainingOnFolder() {
189+
// Get the full data for fine tuning
190+
let foundTrainingModel;
191+
if (data.fineTuneTraining) {
192+
foundTrainingModel = Object.values(data.trainingConfigurations.models)
193+
.find((item) => item.name === data.selectedFineTune);
194+
}
175195
try {
176196
await runTraining(
177197
stagedItems.value.map(({ id }) => id),
178198
data.trainingOutputName,
179199
data.selectedTrainingConfig,
180200
data.annotatedFramesOnly,
201+
undefined,
202+
foundTrainingModel,
181203
);
182204
router.push({ name: 'jobs' });
183205
} catch (err) {
@@ -204,6 +226,7 @@ export default defineComponent({
204226
nameRules,
205227
itemsPerPageOptions,
206228
clientSettings,
229+
modelNames,
207230
models: {
208231
items: trainedModels,
209232
headers: trainedHeadersTmpl.concat({
@@ -255,7 +278,10 @@ export default defineComponent({
255278
<v-card-text>
256279
Add datasets to the staging area and choose a training configuration.
257280
</v-card-text>
258-
<v-row class="mt-4 pt-0">
281+
<v-row
282+
class="mt-4 pt-0"
283+
dense
284+
>
259285
<v-col sm="5">
260286
<v-text-field
261287
v-model="data.trainingOutputName"
@@ -272,7 +298,7 @@ export default defineComponent({
272298
outlined
273299
dense
274300
label="Configuration File (Required)"
275-
:items="data.trainingConfigurations.configs"
301+
:items="data.trainingConfigurations.training.configs"
276302
:hint="data.selectedTrainingConfig"
277303
persistent-hint
278304
>
@@ -321,6 +347,20 @@ export default defineComponent({
321347
Train on ({{ staged.items.value.length }}) Datasets
322348
</v-btn>
323349
</div>
350+
<div class="d-flex flex-row mt-7">
351+
<v-checkbox
352+
v-model="data.fineTuneTraining"
353+
label="Fine Tuning"
354+
hint="Fine tune an existing model"
355+
/>
356+
<v-spacer />
357+
<v-select
358+
v-if="data.fineTuneTraining"
359+
v-model="data.selectedFineTune"
360+
:items="modelNames"
361+
label="Fine Tune Model"
362+
/>
363+
</div>
324364
</div>
325365
<div>
326366
<v-card-title class="text-h4">

client/platform/web-girder/api/rpc.service.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,14 @@ function runTraining(
2525
config: string,
2626
annotatedFramesOnly: boolean,
2727
labelText?: string,
28+
fineTuneModel?: {
29+
name: string;
30+
type: string;
31+
path?: string;
32+
folderId?: string;
33+
},
2834
) {
29-
return girderRest.post('dive_rpc/train', { folderIds, labelText }, {
35+
return girderRest.post('dive_rpc/train', { folderIds, labelText, fineTuneModel }, {
3036
params: {
3137
pipelineName, config, annotatedFramesOnly,
3238
},

client/platform/web-girder/views/RunTrainingMenu.vue

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ export default defineComponent({
3838
const trainingConfigurations = ref<TrainingConfigs | null>(null);
3939
const selectedTrainingConfig = ref<string | null>(null);
4040
const annotatedFramesOnly = ref<boolean>(false);
41+
const fineTuning = ref<boolean>(false);
42+
const selectedFineTune = ref<string>('');
4143
const {
4244
request: _runTrainingRequest,
4345
reset: dismissJobDialog,
@@ -46,10 +48,32 @@ export default defineComponent({
4648
4749
const successMessage = computed(() => `Started training on ${props.selectedDatasetIds.length} dataset(s)`);
4850
51+
const fineTuneModelList = computed(() => {
52+
const modelList: {text: string, type: 'user' | 'system', name: string}[] = [];
53+
if (trainingConfigurations.value?.models) {
54+
Object.entries(trainingConfigurations.value.models)
55+
.forEach(([, value]) => {
56+
modelList.push({
57+
text: `${value.name} - ${value.folderId ? 'User' : 'System'} Model`,
58+
type: value.folderId ? 'user' : 'system',
59+
name: value.name,
60+
});
61+
});
62+
}
63+
modelList.sort((a, b) => b.type.localeCompare(a.type));
64+
return modelList;
65+
});
66+
const selectedFineTuneObject = computed(() => {
67+
if (selectedFineTune.value !== '' && trainingConfigurations.value?.models) {
68+
return Object.values(trainingConfigurations.value.models)
69+
.find((model) => model.name === selectedFineTune.value);
70+
}
71+
return undefined;
72+
});
4973
onBeforeMount(async () => {
5074
const resp = await getTrainingConfigurations();
5175
trainingConfigurations.value = resp;
52-
selectedTrainingConfig.value = resp.default;
76+
selectedTrainingConfig.value = resp.training.default;
5377
});
5478
5579
const trainingDisabled = computed(() => props.selectedDatasetIds.length === 0);
@@ -74,13 +98,16 @@ export default defineComponent({
7498
selectedTrainingConfig.value,
7599
annotatedFramesOnly.value,
76100
labelText.value,
101+
selectedFineTuneObject.value,
77102
);
78103
}
79104
return runTraining(
80105
props.selectedDatasetIds,
81106
outputPipelineName,
82107
selectedTrainingConfig.value,
83108
annotatedFramesOnly.value,
109+
undefined,
110+
selectedFineTuneObject.value,
84111
);
85112
});
86113
menuOpen.value = false;
@@ -116,6 +143,10 @@ export default defineComponent({
116143
labelFile,
117144
clearLabelText,
118145
simplifyTrainingName,
146+
// Fine-Tuning
147+
fineTuning,
148+
fineTuneModelList,
149+
selectedFineTune,
119150
};
120151
},
121152
});
@@ -161,6 +192,7 @@ export default defineComponent({
161192
<v-card
162193
v-if="trainingConfigurations"
163194
outlined
195+
class="training-menu"
164196
>
165197
<v-card-title class="pb-1">
166198
Run Training
@@ -198,7 +230,7 @@ export default defineComponent({
198230
outlined
199231
class="my-4"
200232
label="Configuration File"
201-
:items="trainingConfigurations.configs"
233+
:items="trainingConfigurations.training.configs"
202234
:hint="selectedTrainingConfig"
203235
persistent-hint
204236
>
@@ -225,6 +257,25 @@ export default defineComponent({
225257
persistent-hint
226258
class="pt-0"
227259
/>
260+
<v-checkbox
261+
v-model="fineTuning"
262+
label="Fine Tune Model"
263+
hint="Fine Tune an existing model"
264+
persistent-hint
265+
class="pt-0"
266+
/>
267+
<v-select
268+
v-if="fineTuning"
269+
v-model="selectedFineTune"
270+
outlined
271+
class="my-4"
272+
label="Fine Tune Model"
273+
:items="fineTuneModelList"
274+
item-value="name"
275+
item-text="text"
276+
hint="Model to Fine Tune"
277+
persistent-hint
278+
/>
228279
<v-btn
229280
depressed
230281
block
@@ -248,3 +299,10 @@ export default defineComponent({
248299
/>
249300
</div>
250301
</template>
302+
303+
<style lang="css" scoped>
304+
.training-menu {
305+
max-height: 90vh;
306+
overflow-y: auto;
307+
}
308+
</style>

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ services:
125125
reservations:
126126
devices:
127127
- driver: nvidia
128-
device_ids: ['1']
128+
device_ids: ['0']
129129
capabilities: [gpu]
130130
environment:
131131
- "WORKER_WATCHING_QUEUES=training"

0 commit comments

Comments
 (0)