microsoft/vscode-languagedetection

Public

mirrored fromhttps://github.com/microsoft/vscode-languagedetectionAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
4b130f80ed16fb3ccb072fc6bb7a9b2b26ef7a9b

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

lib/index.ts

219lines · modecode

1import { Rank, tensor, Tensor, io, setBackend, env } from '@tensorflow/tfjs-core';
2import { GraphModel, loadGraphModel } from '@tensorflow/tfjs-converter';
3
/**
 * One language guess produced by a model run.
 */
export interface ModelResult {
  /** Language identifier exactly as emitted by the model's language output. */
  languageId: string;

  /** Score the model paired with `languageId`; results are ordered by it, highest first. */
  confidence: number;
}
8
/**
 * An `io.IOHandler` that serves a model from data that is already in memory.
 *
 * It is constructed with the parsed contents of a `model.json` and a single
 * `ArrayBuffer` holding the weight data, and assembles the
 * `io.ModelArtifacts` object from the two.
 */
class InMemoryIOHandler implements io.IOHandler {

  /**
   * @param modelJSON Parsed contents of the model's `model.json`.
   * @param weights Weight data returned as-is alongside the manifest entries.
   */
  constructor(
    private readonly modelJSON: io.ModelJSON,
    private readonly weights: ArrayBuffer) {
  }

  /**
   * Builds the model artifacts from the in-memory JSON and weights.
   *
   * @returns The assembled artifacts.
   * @throws Error when the JSON has neither a model topology nor a weights manifest.
   */
  async load(): Promise<io.ModelArtifacts> {
    const { modelTopology, weightsManifest } = this.modelJSON;

    // Loose `== null` on purpose: a key that is simply absent from the parsed
    // JSON reads as `undefined`, which a strict `=== null` check would miss.
    if (modelTopology == null && weightsManifest == null) {
      throw new Error(
        `The model contains neither model topology or manifest for weights.`);
    }

    return this.getModelArtifactsForJSON(
      this.modelJSON, (manifest) => this.loadWeights(manifest));
  }

  /**
   * Copies the fields of `modelJSON` into an `io.ModelArtifacts` object.
   * Optional sections are copied only when present (neither `null` nor
   * `undefined`); weights are resolved through `loadWeights` only when a
   * manifest exists.
   *
   * @param modelJSON Parsed `model.json` contents.
   * @param loadWeights Resolves a weights manifest to `[weightSpecs, weightData]`.
   */
  private async getModelArtifactsForJSON(
    modelJSON: io.ModelJSON,
    loadWeights: (weightsManifest: io.WeightsManifestConfig) => Promise<[
      /* weightSpecs */ io.WeightsManifestEntry[], /* weightData */ ArrayBuffer
    ]>): Promise<io.ModelArtifacts> {
    const modelArtifacts: io.ModelArtifacts = {
      modelTopology: modelJSON.modelTopology,
      format: modelJSON.format,
      generatedBy: modelJSON.generatedBy,
      convertedBy: modelJSON.convertedBy
    };

    if (modelJSON.trainingConfig != null) {
      modelArtifacts.trainingConfig = modelJSON.trainingConfig;
    }
    if (modelJSON.weightsManifest != null) {
      const [weightSpecs, weightData] =
        await loadWeights(modelJSON.weightsManifest);
      modelArtifacts.weightSpecs = weightSpecs;
      modelArtifacts.weightData = weightData;
    }
    if (modelJSON.signature != null) {
      modelArtifacts.signature = modelJSON.signature;
    }
    if (modelJSON.userDefinedMetadata != null) {
      modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;
    }
    if (modelJSON.modelInitializer != null) {
      modelArtifacts.modelInitializer = modelJSON.modelInitializer;
    }

    return modelArtifacts;
  }

  /**
   * Flattens the weight entries of every manifest group into one list and
   * pairs it with the weight buffer given to the constructor.
   */
  private async loadWeights(weightsManifest: io.WeightsManifestConfig): Promise<[io.WeightsManifestEntry[], ArrayBuffer]> {
    const weightSpecs: io.WeightsManifestEntry[] = [];
    for (const group of weightsManifest) {
      weightSpecs.push(...group.weights);
    }

    return [weightSpecs, this.weights];
  }
}
71
/**
 * Optional settings accepted by the `ModelOperations` constructor.
 * Every field may be omitted; the constructor then falls back to its default.
 */
export interface ModelOperationsOptions {
  /** Supplies the parsed `model.json`. Defaults to reading it from disk via `fs`. */
  modelJsonLoaderFunc?: () => Promise<Record<string, any>>;

  /** Supplies the weight data. Defaults to reading the weights file from disk via `fs`. */
  weightsLoaderFunc?: () => Promise<ArrayBuffer>;

  /** Content shorter than this yields no results. Defaults to 20. */
  minContentSize?: number;

  /** Content is cut to this many characters before it is run through the model. Defaults to 100000. */
  maxContentSize?: number;

  /** When true, `\r\n` is replaced with `\n` before the model runs. Defaults to true. */
  normalizeNewline?: boolean;
}
79
/**
 * Loads the language-detection graph model on first use and runs it over
 * source text, returning the model's language guesses ordered by confidence.
 *
 * The model JSON, the weights and the loaded model are cached on the
 * instance. Call `dispose()` to release the loaded model.
 */
export class ModelOperations {
  private static readonly DEFAULT_MAX_CONTENT_SIZE = 100000;
  private static readonly DEFAULT_MIN_CONTENT_SIZE = 20;

  /**
   * Reads `fileName` from the `model` directory two levels above `__dirname`.
   * `fs` and `path` are imported lazily so they are only touched when the
   * default (Node) loaders are actually used.
   *
   * @returns The file contents as a Node `Buffer`.
   */
  private static async readModelFile(fileName: string) {
    const fs = await import('fs');
    const path = await import('path');
    return fs.promises.readFile(path.join(__dirname, '..', '..', 'model', fileName));
  }

  /** Default model JSON loader: parses `model/model.json` from disk. */
  private static NODE_MODEL_JSON_FUNC: () => Promise<{ [key: string]: any }> = async () => {
    const data = await ModelOperations.readModelFile('model.json');
    return JSON.parse(data.toString());
  }

  /** Default weights loader: reads `model/group1-shard1of1.bin` from disk. */
  private static NODE_WEIGHTS_FUNC: () => Promise<ArrayBuffer> = async () => {
    const data = await ModelOperations.readModelFile('group1-shard1of1.bin');
    const { buffer, byteOffset, byteLength } = data;

    // A Buffer can be a view onto a larger ArrayBuffer. Hand back the backing
    // store only when the view covers all of it; otherwise copy out exactly
    // the bytes that belong to the file.
    if (byteOffset === 0 && byteLength === buffer.byteLength) {
      return buffer as ArrayBuffer;
    }
    return buffer.slice(byteOffset, byteOffset + byteLength) as ArrayBuffer;
  }

  private _model: GraphModel | undefined;
  /** In-flight model load, shared by concurrent callers so the model is built once. */
  private _modelLoading: Promise<GraphModel> | undefined;
  private _modelJson: io.ModelJSON | undefined;
  private _weights: ArrayBuffer | undefined;
  private readonly _minContentSize: number;
  private readonly _maxContentSize: number;
  private readonly _modelJsonLoaderFunc: () => Promise<{ [key: string]: any }>;
  private readonly _weightsLoaderFunc: () => Promise<ArrayBuffer>;
  private readonly _normalizeNewline: boolean;

  /**
   * @param modelOptions Optional overrides for the loaders, the content size
   *   limits and newline normalization. Omitted fields use the defaults.
   */
  constructor(modelOptions?: ModelOperationsOptions) {
    this._modelJsonLoaderFunc = modelOptions?.modelJsonLoaderFunc ?? ModelOperations.NODE_MODEL_JSON_FUNC;
    this._weightsLoaderFunc = modelOptions?.weightsLoaderFunc ?? ModelOperations.NODE_WEIGHTS_FUNC;
    this._minContentSize = modelOptions?.minContentSize ?? ModelOperations.DEFAULT_MIN_CONTENT_SIZE;
    this._maxContentSize = modelOptions?.maxContentSize ?? ModelOperations.DEFAULT_MAX_CONTENT_SIZE;
    this._normalizeNewline = modelOptions?.normalizeNewline ?? true;
  }

  /** Returns the model JSON, invoking the loader only on the first call. */
  private async getModelJSON(): Promise<io.ModelJSON> {
    if (this._modelJson) {
      return this._modelJson;
    }

    // TODO: validate model.json
    this._modelJson = await this._modelJsonLoaderFunc() as io.ModelJSON;
    return this._modelJson;
  }

  /** Returns the weight data, invoking the loader only on the first call. */
  private async getWeights(): Promise<ArrayBuffer> {
    if (this._weights) {
      return this._weights;
    }

    // TODO: validate weights
    this._weights = await this._weightsLoaderFunc();
    return this._weights;
  }

  /**
   * Configures the tfjs environment, selects the CPU backend and builds the
   * graph model from the (cached) JSON and weights.
   *
   * @throws Error when the CPU backend cannot be selected.
   */
  private async createModel(): Promise<GraphModel> {
    // These 2 env set's just suppress some warnings that get logged that
    // are not applicable for this use case.
    const tfEnv = env();
    tfEnv.set('IS_NODE', false);
    tfEnv.set('PROD', true);

    await import('@tensorflow/tfjs-backend-cpu');
    if (!(await setBackend('cpu'))) {
      throw new Error('Unable to set backend to CPU.');
    }

    // The two loads do not depend on each other, so run them side by side.
    const [resolvedModelJSON, resolvedWeights] = await Promise.all([
      this.getModelJSON(),
      this.getWeights(),
    ]);
    return loadGraphModel(new InMemoryIOHandler(resolvedModelJSON, resolvedWeights));
  }

  /**
   * Returns the loaded model, creating it on first use. Callers that arrive
   * while a load is in progress share that load instead of starting another.
   * A failed load is not cached, so a later call tries again.
   */
  private async loadModel(): Promise<GraphModel> {
    if (this._model) {
      return this._model;
    }

    if (!this._modelLoading) {
      this._modelLoading = this.createModel();
    }

    try {
      this._model = await this._modelLoading;
      return this._model;
    } finally {
      this._modelLoading = undefined;
    }
  }

  /**
   * Runs the model over `content`.
   *
   * @param content Source text to classify.
   * @returns One entry per language emitted by the model, sorted by
   *   descending confidence. Empty when `content` is empty or shorter than
   *   the configured minimum size.
   */
  public async runModel(content: string): Promise<Array<ModelResult>> {
    if (!content || content.length < this._minContentSize) {
      return [];
    }

    const model = await this.loadModel();

    // larger files cause a "RangeError: Maximum call stack size exceeded" in tfjs.
    // So grab the first X characters as that should be good enough for guessing.
    let text = content;
    if (text.length >= this._maxContentSize) {
      text = text.substring(0, this._maxContentSize);
    }

    if (this._normalizeNewline) {
      text = text.replace(/\r\n/g, '\n');
    }

    const input = tensor([text]);
    let predicted: Tensor<Rank> | Tensor<Rank>[] | undefined;
    try {
      // call out to the model
      predicted = await model.executeAsync(input);
      const probabilitiesTensor: Tensor<Rank> = Array.isArray(predicted) ? predicted[0]! : predicted;
      const languageTensor: Tensor<Rank> = Array.isArray(predicted) ? predicted[1]! : predicted;
      const probabilities = probabilitiesTensor.dataSync<'float32'>();
      const langs = languageTensor.dataSync<'string'>();

      const results: Array<ModelResult> = [];
      for (let i = 0; i < langs.length; i++) {
        results.push({
          languageId: langs[i],
          confidence: probabilities[i],
        });
      }

      return results.sort((a, b) => b.confidence - a.confidence);
    } finally {
      // The values were copied out by dataSync() above; release the input and
      // every output tensor so repeated runs do not accumulate tensor memory.
      input.dispose();
      if (predicted) {
        const outputs = Array.isArray(predicted) ? predicted : [predicted];
        for (const output of outputs) {
          output.dispose();
        }
      }
    }
  }

  /**
   * Releases the loaded model. The reference is cleared as well, so a later
   * `runModel` call loads a fresh model instead of using the disposed one.
   * The cached model JSON and weights are kept.
   */
  public dispose(): void {
    this._model?.dispose();
    this._model = undefined;
  }
}
220