microsoft/vscode-languagedetection

Public

mirrored from https://github.com/microsoft/vscode-languagedetection · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d9c71921ab051a40478783ff56ccbdeb96647e50

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

lib/index.ts

145 lines · mode: code

1import { GraphModel, io, loadGraphModel, Rank, setBackend, tensor, Tensor } from '@tensorflow/tfjs';
2
/**
 * A single candidate language reported by {@link ModelOperations.runModel}.
 */
export interface ModelResult {
	/** Language identifier exactly as emitted by the model's language output. */
	languageId: string;
	/** The model's score for this language; `runModel` orders results by it, highest first. */
	confidence: number;
}
7
/**
 * A TensorFlow.js `IOHandler` that serves a graph model from artifacts already
 * held in memory: the parsed `model.json` object plus one buffer holding the
 * weight bytes. This lets `loadGraphModel` load the model without fetching
 * anything itself.
 *
 * The artifact assembly appears adapted from tfjs' own `getModelArtifactsForJSON`
 * helper.
 */
class InMemoryIOHandler implements io.IOHandler {

	/**
	 * @param modelJSON The parsed model.json contents.
	 * @param weights The weight data, returned unchanged as the model's `weightData`
	 *   for every entry of `modelJSON.weightsManifest`.
	 */
	constructor(private readonly modelJSON: io.ModelJSON,
		private readonly weights: ArrayBuffer) {
	}

	/**
	 * Builds the model artifacts consumed by `loadGraphModel`.
	 *
	 * @throws Error if the model JSON has neither a model topology nor a weights manifest.
	 */
	async load(): Promise<io.ModelArtifacts> {
		// We do not allow both modelTopology and weightsManifest to be missing.
		// `== null` (rather than `=== null`) also treats keys that are simply
		// absent from the JSON (i.e. `undefined`) as missing.
		const modelTopology = this.modelJSON.modelTopology;
		const weightsManifest = this.modelJSON.weightsManifest;
		if (modelTopology == null && weightsManifest == null) {
			throw new Error(
				`The model contains neither model topology or manifest for weights.`);
		}

		return this.getModelArtifactsForJSON(
			this.modelJSON, (manifest) => this.loadWeights(manifest));
	}

	/**
	 * Copies the model metadata from `modelJSON` into a `ModelArtifacts` object and,
	 * when a weights manifest is present, attaches the weight specs and data
	 * produced by `loadWeights`.
	 *
	 * Optional fields are copied only when present. `!= null` is used so that
	 * fields missing from the JSON (`undefined`) are skipped as well as explicit
	 * `null`s. In particular, an absent `weightsManifest` is never handed to
	 * `loadWeights`.
	 */
	private async getModelArtifactsForJSON(
		modelJSON: io.ModelJSON,
		loadWeights: (weightsManifest: io.WeightsManifestConfig) => Promise<[
			/* weightSpecs */ io.WeightsManifestEntry[], /* weightData */ ArrayBuffer
		]>): Promise<io.ModelArtifacts> {
		const modelArtifacts: io.ModelArtifacts = {
			modelTopology: modelJSON.modelTopology,
			format: modelJSON.format,
			generatedBy: modelJSON.generatedBy,
			convertedBy: modelJSON.convertedBy
		};

		if (modelJSON.trainingConfig != null) {
			modelArtifacts.trainingConfig = modelJSON.trainingConfig;
		}
		if (modelJSON.weightsManifest != null) {
			const [weightSpecs, weightData] =
				await loadWeights(modelJSON.weightsManifest);
			modelArtifacts.weightSpecs = weightSpecs;
			modelArtifacts.weightData = weightData;
		}
		if (modelJSON.signature != null) {
			modelArtifacts.signature = modelJSON.signature;
		}
		if (modelJSON.userDefinedMetadata != null) {
			modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;
		}
		if (modelJSON.modelInitializer != null) {
			modelArtifacts.modelInitializer = modelJSON.modelInitializer;
		}

		return modelArtifacts;
	}

	/**
	 * Flattens the weight specs of every group in the manifest and pairs them with
	 * the single in-memory weight buffer.
	 */
	private async loadWeights(weightsManifest: io.WeightsManifestConfig): Promise<[io.WeightsManifestEntry[], ArrayBuffer]> {
		const weightSpecs: io.WeightsManifestEntry[] = [];
		for (const group of weightsManifest) {
			weightSpecs.push(...group.weights);
		}

		// All groups' data is served from the one buffer supplied to the constructor.
		return [weightSpecs, this.weights];
	}
}
70
/**
 * Lazily loads the language-detection graph model and runs it on text content.
 *
 * The model JSON and the weights are obtained through the supplied factories the
 * first time they are needed and cached for later loads.
 */
export class ModelOperations {
	private _model: GraphModel | undefined;
	private _modelJson: io.ModelJSON | undefined;
	private _weights: ArrayBuffer | undefined;

	/**
	 * @param modelJSONFunc Produces the model.json contents (treated as `io.ModelJSON`).
	 * @param weightsFunc Produces the weight bytes for the model's weights manifest.
	 */
	constructor(private readonly modelJSONFunc: () => Promise<any>,
		private readonly weightsFunc: () => Promise<ArrayBuffer>) {
	}

	/** Returns the model JSON, invoking `modelJSONFunc` only until a value is cached. */
	private async getModelJSON() {
		if (this._modelJson) {
			return this._modelJson;
		}
		this._modelJson = await this.modelJSONFunc() as io.ModelJSON;
		return this._modelJson;
	}

	/** Returns the weight bytes, invoking `weightsFunc` only until a value is cached. */
	private async getWeights() {
		if (this._weights) {
			return this._weights;
		}
		this._weights = await this.weightsFunc();
		return this._weights;
	}

	/**
	 * Returns the loaded graph model. On first use it switches tfjs to the CPU
	 * backend and loads the model from the cached JSON and weights.
	 */
	private async loadModel(): Promise<GraphModel> {
		if (this._model) {
			return this._model;
		}

		await setBackend('cpu');

		const resolvedModelJSON = await this.getModelJSON();
		const resolvedWeights = await this.getWeights();
		this._model = await loadGraphModel(new InMemoryIOHandler(resolvedModelJSON, resolvedWeights));
		return this._model;
	}

	/**
	 * Runs the model on `content` and returns one result per language reported by
	 * the model, sorted by descending confidence.
	 *
	 * @param content The text to classify. Empty content yields `[]` without
	 *   loading the model.
	 */
	public async runModel(content: string): Promise<Array<ModelResult>> {
		if (!content) {
			return [];
		}

		const model = await this.loadModel();

		// tfjs tensors must be released with dispose(). Without this, every call
		// would leak the input tensor and each tensor the model returns.
		const input = tensor([content]);
		let outputs: Tensor<Rank>[] = [];
		try {
			// call out to the model
			const predicted = await model.executeAsync(input);
			outputs = Array.isArray(predicted) ? predicted : [predicted];

			// NOTE(review): assumes output 0 holds the probabilities and output 1 the
			// language ids. With a single (non-array) output, both read the same tensor.
			const probabilitiesTensor: Tensor<Rank> = Array.isArray(predicted) ? predicted[0]! : predicted;
			const languageTensor: Tensor<Rank> = Array.isArray(predicted) ? predicted[1]! : predicted;
			const probabilities = probabilitiesTensor.dataSync() as Float32Array;
			// dataSync() is typed for numeric data. The language output is used as strings.
			const langs: Array<string> = languageTensor.dataSync() as any;

			const results: Array<ModelResult> = [];
			for (let i = 0; i < langs.length; i++) {
				results.push({
					languageId: langs[i],
					confidence: probabilities[i],
				});
			}

			return results.sort((a, b) => b.confidence - a.confidence);
		} finally {
			input.dispose();
			for (const output of outputs) {
				output.dispose();
			}
		}
	}

	/**
	 * Releases the loaded model, if any. The model is reloaded (from the cached
	 * JSON and weights) by the next `runModel` call, instead of running a
	 * disposed model.
	 */
	public dispose() {
		this._model?.dispose();
		this._model = undefined;
	}
}
146